diff --git a/third_party/METIS/.gitignore b/third_party/METIS/.gitignore deleted file mode 100644 index 4796f27827f898a91c2fc13cbdfc15b778c4808d..0000000000000000000000000000000000000000 --- a/third_party/METIS/.gitignore +++ /dev/null @@ -1,61 +0,0 @@ -# Prerequisites -*.d - -# Object files -*.o -*.ko -*.obj -*.elf - -# Linker output -*.ilk -*.map -*.exp - -# Precompiled Headers -*.gch -*.pch - -# Libraries -*.lib -*.a -*.la -*.lo - -# Shared objects (inc. Windows DLLs) -*.dll -*.so -*.so.* -*.dylib - -# Executables -*.exe -*.out -*.app -*.i*86 -*.x86_64 -*.hex - -# Debug files -*.dSYM/ -*.su -*.idb -*.pdb - -# Kernel Module Compile Results -*.mod* -*.cmd -.tmp_versions/ -modules.order -Module.symvers -Mkfile.old -dkms.conf - -# GK things -build/ -graphs/*.part.* -graphs/*.iperm -graphs/*.epart.* -graphs/*.npart.* -.svn/ - diff --git a/third_party/dlpack/.gitignore b/third_party/dlpack/.gitignore deleted file mode 100644 index 21c857e0ff5c9edf91127771ff5e54aa36a03afa..0000000000000000000000000000000000000000 --- a/third_party/dlpack/.gitignore +++ /dev/null @@ -1,32 +0,0 @@ -# Compiled Object files -*.slo -*.lo -*.o -*.obj - -# Precompiled Headers -*.gch -*.pch - -# Compiled Dynamic libraries -*.so -*.dylib -*.dll - -# Fortran module files -*.mod -*.smod - -# Compiled Static libraries -*.lai -*.la -*.a -*.lib - -# Executables -*.exe -*.out -*.app -*~ -build -bin diff --git a/third_party/dmlc-core/.gitignore b/third_party/dmlc-core/.gitignore deleted file mode 100644 index 124d39604b6dfdc90c9c320f1c991be2d85f0cf3..0000000000000000000000000000000000000000 --- a/third_party/dmlc-core/.gitignore +++ /dev/null @@ -1,48 +0,0 @@ -# Compiled Object files -*.slo -*.lo -*.o -*.obj - -# Precompiled Headers -*.gch -*.pch - -# Compiled Dynamic libraries -*.so -*.dylib -*.dll - -# Fortran module files -*.mod - -# Compiled Static libraries -*.lai -*.la -*.a -*.lib - -# Executables -*.exe -*.out -*.app -*~ -config.mk -*.pyc - -# Vim -*.swp -*.swo -*.swn -*.csv -.vimrc - -# Emacs -.clang_complete -deps -recommonmark -build - -# CLion -.idea -cmake-build-* diff --git a/third_party/dmlc-core/make/config.mk b/third_party/dmlc-core/make/config.mk new file mode 100644 index 0000000000000000000000000000000000000000..a6be9ad5934f808bf8b5ea241afd5ce145683ea6 --- /dev/null +++ b/third_party/dmlc-core/make/config.mk @@ -0,0 +1,53 @@ +#----------------------------------------------------- +# dmlc-core: the configuration compile script +# +# This is the default configuration setup for +# If you want to change configuration, do the following steps: +# +# - copy this file to the root of dmlc-core folder +# - modify the configuration you want +# - type make or make -j n on each of the folder +#---------------------------------------------------- + +# choice of compiler +export CC = gcc +export CXX = g++ +export MPICXX = mpicxx + +# choice of archiver +export AR = ar + +# the additional link flags you want to add +ADD_LDFLAGS = + +# the additional compile flags you want to add +ADD_CFLAGS = + +# whether to compile with -fPIC option +# Note: to build shared library(so files), fPIC is required +WITH_FPIC = 1 + +# whether use openmp during compile +USE_OPENMP = 0 + +# whether use HDFS support during compile +USE_HDFS = 0 + +# whether use AWS S3 support during compile +USE_S3 = 0 + +# whether use Azure blob support during compile +USE_AZURE = 0 + +# path to libjvm.so +LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server + +# whether building unittest (gtest is required) +BUILD_TEST=0 + +# path to gtest library (only used when $BUILD_TEST=1) +# there should be an include path in $GTEST_PATH/include and library in $GTEST_PATH/lib +GTEST_PATH= + +# path to third-party dependences such as glog +DEPS_PATH= diff --git a/third_party/googletest/.gitignore b/third_party/googletest/.gitignore deleted file mode 100644 index f08cb72a33cd199478f41be1bd487f916330472c..0000000000000000000000000000000000000000 --- a/third_party/googletest/.gitignore +++ /dev/null @@ -1,84 +0,0 @@ -# Ignore CI build directory -build/ -xcuserdata -cmake-build-debug/ -.idea/ -bazel-bin -bazel-genfiles -bazel-googletest -bazel-out -bazel-testlogs -# python -*.pyc - -# Visual Studio files -.vs -*.sdf -*.opensdf -*.VC.opendb -*.suo -*.user -_ReSharper.Caches/ -Win32-Debug/ -Win32-Release/ -x64-Debug/ -x64-Release/ - -# Ignore autoconf / automake files -Makefile.in -aclocal.m4 -configure -build-aux/ -autom4te.cache/ -googletest/m4/libtool.m4 -googletest/m4/ltoptions.m4 -googletest/m4/ltsugar.m4 -googletest/m4/ltversion.m4 -googletest/m4/lt~obsolete.m4 -googlemock/m4 - -# Ignore generated directories. -googlemock/fused-src/ -googletest/fused-src/ - -# macOS files -.DS_Store -googletest/.DS_Store -googletest/xcode/.DS_Store - -# Ignore cmake generated directories and files. -CMakeFiles -CTestTestfile.cmake -Makefile -cmake_install.cmake -googlemock/CMakeFiles -googlemock/CTestTestfile.cmake -googlemock/Makefile -googlemock/cmake_install.cmake -googlemock/gtest -/bin -/googlemock/gmock.dir -/googlemock/gmock_main.dir -/googlemock/RUN_TESTS.vcxproj.filters -/googlemock/RUN_TESTS.vcxproj -/googlemock/INSTALL.vcxproj.filters -/googlemock/INSTALL.vcxproj -/googlemock/gmock_main.vcxproj.filters -/googlemock/gmock_main.vcxproj -/googlemock/gmock.vcxproj.filters -/googlemock/gmock.vcxproj -/googlemock/gmock.sln -/googlemock/ALL_BUILD.vcxproj.filters -/googlemock/ALL_BUILD.vcxproj -/lib -/Win32 -/ZERO_CHECK.vcxproj.filters -/ZERO_CHECK.vcxproj -/RUN_TESTS.vcxproj.filters -/RUN_TESTS.vcxproj -/INSTALL.vcxproj.filters -/INSTALL.vcxproj -/googletest-distribution.sln -/CMakeCache.txt -/ALL_BUILD.vcxproj.filters -/ALL_BUILD.vcxproj diff --git a/third_party/googletest/googlemock/build-aux/.keep b/third_party/googletest/googlemock/build-aux/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/googletest/googlemock/make/Makefile b/third_party/googletest/googlemock/make/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..386293a0eabd68a687f3fc77840e97a1aca2b762 --- /dev/null +++ b/third_party/googletest/googlemock/make/Makefile @@ -0,0 +1,117 @@ +# A sample Makefile for building both Google Mock and Google Test and +# using them in user tests. This file is self-contained, so you don't +# need to use the Makefile in Google Test's source tree. Please tweak +# it to suit your environment and project. You may want to move it to +# your project's root directory. +# +# SYNOPSIS: +# +# make [all] - makes everything. +# make TARGET - makes the given target. +# make clean - removes all files generated by make. + +# Please tweak the following variable definitions as needed by your +# project, except GMOCK_HEADERS and GTEST_HEADERS, which you can use +# in your own targets but shouldn't modify. + +# Points to the root of Google Test, relative to where this file is. +# Remember to tweak this if you move this file, or if you want to use +# a copy of Google Test at a different location. +GTEST_DIR = ../../googletest + +# Points to the location of the Google Test libraries +GTEST_LIB_DIR = . + +# Points to the root of Google Mock, relative to where this file is. +# Remember to tweak this if you move this file. +GMOCK_DIR = .. + +# Where to find user code. +USER_DIR = ../test + +# Flags passed to the preprocessor. +# Set Google Test and Google Mock's header directories as system +# directories, such that the compiler doesn't generate warnings in +# these headers. +CPPFLAGS += -isystem $(GTEST_DIR)/include -isystem $(GMOCK_DIR)/include + +# Flags passed to the C++ compiler. +CXXFLAGS += -g -Wall -Wextra -pthread -std=c++11 + +# Google Test libraries +GTEST_LIBS = libgtest.a libgtest_main.a libgmock.a libgmock_main.a + +# All tests produced by this Makefile. Remember to add new tests you +# created to the list. +TESTS = gmock_test + +# All Google Test headers. Usually you shouldn't change this +# definition. +GTEST_HEADERS = $(GTEST_DIR)/include/gtest/*.h \ + $(GTEST_DIR)/include/gtest/internal/*.h + +# All Google Mock headers. Note that all Google Test headers are +# included here too, as they are #included by Google Mock headers. +# Usually you shouldn't change this definition. +GMOCK_HEADERS = $(GMOCK_DIR)/include/gmock/*.h \ + $(GMOCK_DIR)/include/gmock/internal/*.h \ + $(GTEST_HEADERS) + +# House-keeping build targets. + +all : $(GTEST_LIBS) $(TESTS) + +clean : + rm -f $(GTEST_LIBS) $(TESTS) *.o + +# Builds gmock.a and gmock_main.a. These libraries contain both +# Google Mock and Google Test. A test should link with either gmock.a +# or gmock_main.a, depending on whether it defines its own main() +# function. It's fine if your test only uses features from Google +# Test (and not Google Mock). + +# Usually you shouldn't tweak such internal variables, indicated by a +# trailing _. +GTEST_SRCS_ = $(GTEST_DIR)/src/*.cc $(GTEST_DIR)/src/*.h $(GTEST_HEADERS) +GMOCK_SRCS_ = $(GMOCK_DIR)/src/*.cc $(GMOCK_HEADERS) + +# For simplicity and to avoid depending on implementation details of +# Google Mock and Google Test, the dependencies specified below are +# conservative and not optimized. This is fine as Google Mock and +# Google Test compile fast and for ordinary users their source rarely +# changes. +gtest-all.o : $(GTEST_SRCS_) + $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) -I$(GMOCK_DIR) $(CXXFLAGS) \ + -c $(GTEST_DIR)/src/gtest-all.cc + +gtest_main.o : $(GTEST_SRCS_) + $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) -I$(GMOCK_DIR) $(CXXFLAGS) \ + -c $(GTEST_DIR)/src/gtest_main.cc + +gmock-all.o : $(GMOCK_SRCS_) + $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) -I$(GMOCK_DIR) $(CXXFLAGS) \ + -c $(GMOCK_DIR)/src/gmock-all.cc + +gmock_main.o : $(GMOCK_SRCS_) + $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) -I$(GMOCK_DIR) $(CXXFLAGS) \ + -c $(GMOCK_DIR)/src/gmock_main.cc + +libgtest.a : gtest-all.o + $(AR) $(ARFLAGS) $@ $^ + +libgtest_main.a : gtest_main.o + $(AR) $(ARFLAGS) $@ $^ + +libgmock.a : gmock-all.o + $(AR) $(ARFLAGS) $@ $^ + +libgmock_main.a : gmock_main.o + $(AR) $(ARFLAGS) $@ $^ + +# Builds a sample test. + +gmock_test.o : $(USER_DIR)/gmock_test.cc $(GMOCK_HEADERS) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/gmock_test.cc + +gmock_test : gmock_test.o $(GTEST_LIBS) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -L$(GTEST_LIB_DIR) -lgmock -lpthread $^ -o $@ diff --git a/third_party/googletest/googletest/make/Makefile b/third_party/googletest/googletest/make/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..b62da67a47dd23e9a3eddb9268b6d5f33fa521fd --- /dev/null +++ b/third_party/googletest/googletest/make/Makefile @@ -0,0 +1,88 @@ +# A sample Makefile for building Google Test and using it in user +# tests. Please tweak it to suit your environment and project. You +# may want to move it to your project's root directory. +# +# SYNOPSIS: +# +# make [all] - makes everything. +# make TARGET - makes the given target. +# make clean - removes all files generated by make. + +# Please tweak the following variable definitions as needed by your +# project, except GTEST_HEADERS, which you can use in your own targets +# but shouldn't modify. + +# Points to the root of Google Test, relative to where this file is. +# Remember to tweak this if you move this file. +GTEST_DIR = .. + +# Points to the location of the Google Test libraries +GTEST_LIB_DIR = . + +# Where to find user code. +USER_DIR = ../samples + +# Flags passed to the preprocessor. +# Set Google Test's header directory as a system directory, such that +# the compiler doesn't generate warnings in Google Test headers. +CPPFLAGS += -isystem $(GTEST_DIR)/include + +# Flags passed to the C++ compiler. +CXXFLAGS += -g -Wall -Wextra -pthread -std=c++11 + +# Google Test libraries +GTEST_LIBS = libgtest.a libgtest_main.a + +# All tests produced by this Makefile. Remember to add new tests you +# created to the list. +TESTS = sample1_unittest + +# All Google Test headers. Usually you shouldn't change this +# definition. +GTEST_HEADERS = $(GTEST_DIR)/include/gtest/*.h \ + $(GTEST_DIR)/include/gtest/internal/*.h + +# House-keeping build targets. + +all : $(GTEST_LIBS) $(TESTS) + +clean : + rm -f $(GTEST_LIBS) $(TESTS) *.o + +# Builds gtest.a and gtest_main.a. + +# Usually you shouldn't tweak such internal variables, indicated by a +# trailing _. +GTEST_SRCS_ = $(GTEST_DIR)/src/*.cc $(GTEST_DIR)/src/*.h $(GTEST_HEADERS) + +# For simplicity and to avoid depending on Google Test's +# implementation details, the dependencies specified below are +# conservative and not optimized. This is fine as Google Test +# compiles fast and for ordinary users its source rarely changes. +gtest-all.o : $(GTEST_SRCS_) + $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) $(CXXFLAGS) -c \ + $(GTEST_DIR)/src/gtest-all.cc + +gtest_main.o : $(GTEST_SRCS_) + $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) $(CXXFLAGS) -c \ + $(GTEST_DIR)/src/gtest_main.cc + +libgtest.a : gtest-all.o + $(AR) $(ARFLAGS) $@ $^ + +libgtest_main.a : gtest-all.o gtest_main.o + $(AR) $(ARFLAGS) $@ $^ + +# Builds a sample test. A test should link with either gtest.a or +# gtest_main.a, depending on whether it defines its own main() +# function. + +sample1.o : $(USER_DIR)/sample1.cc $(USER_DIR)/sample1.h $(GTEST_HEADERS) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/sample1.cc + +sample1_unittest.o : $(USER_DIR)/sample1_unittest.cc \ + $(USER_DIR)/sample1.h $(GTEST_HEADERS) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/sample1_unittest.cc + +sample1_unittest : sample1.o sample1_unittest.o $(GTEST_LIBS) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -L$(GTEST_LIB_DIR) -lgtest_main -lpthread $^ -o $@ diff --git a/third_party/googletest/googletest/scripts/test/Makefile b/third_party/googletest/googletest/scripts/test/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..cdff584637b7a6e9df1fa43ce8f588c43815e561 --- /dev/null +++ b/third_party/googletest/googletest/scripts/test/Makefile @@ -0,0 +1,59 @@ +# A Makefile for fusing Google Test and building a sample test against it. +# +# SYNOPSIS: +# +# make [all] - makes everything. +# make TARGET - makes the given target. +# make check - makes everything and runs the built sample test. +# make clean - removes all files generated by make. + +# Points to the root of fused Google Test, relative to where this file is. +FUSED_GTEST_DIR = output + +# Paths to the fused gtest files. +FUSED_GTEST_H = $(FUSED_GTEST_DIR)/gtest/gtest.h +FUSED_GTEST_ALL_CC = $(FUSED_GTEST_DIR)/gtest/gtest-all.cc + +# Where to find the sample test. +SAMPLE_DIR = ../../samples + +# Where to find gtest_main.cc. +GTEST_MAIN_CC = ../../src/gtest_main.cc + +# Flags passed to the preprocessor. +# We have no idea here whether pthreads is available in the system, so +# disable its use. +CPPFLAGS += -I$(FUSED_GTEST_DIR) -DGTEST_HAS_PTHREAD=0 + +# Flags passed to the C++ compiler. +CXXFLAGS += -g + +all : sample1_unittest + +check : all + ./sample1_unittest + +clean : + rm -rf $(FUSED_GTEST_DIR) sample1_unittest *.o + +$(FUSED_GTEST_H) : + ../fuse_gtest_files.py $(FUSED_GTEST_DIR) + +$(FUSED_GTEST_ALL_CC) : + ../fuse_gtest_files.py $(FUSED_GTEST_DIR) + +gtest-all.o : $(FUSED_GTEST_H) $(FUSED_GTEST_ALL_CC) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(FUSED_GTEST_DIR)/gtest/gtest-all.cc + +gtest_main.o : $(FUSED_GTEST_H) $(GTEST_MAIN_CC) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(GTEST_MAIN_CC) + +sample1.o : $(SAMPLE_DIR)/sample1.cc $(SAMPLE_DIR)/sample1.h + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(SAMPLE_DIR)/sample1.cc + +sample1_unittest.o : $(SAMPLE_DIR)/sample1_unittest.cc \ + $(SAMPLE_DIR)/sample1.h $(FUSED_GTEST_H) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(SAMPLE_DIR)/sample1_unittest.cc + +sample1_unittest : sample1.o sample1_unittest.o gtest-all.o gtest_main.o + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $^ -o $@ diff --git a/third_party/libxsmm/.gitignore b/third_party/libxsmm/.gitignore deleted file mode 100644 index 5429e7db7578b1f5691d8d8f15a39aa6056be4c9..0000000000000000000000000000000000000000 --- a/third_party/libxsmm/.gitignore +++ /dev/null @@ -1,96 +0,0 @@ -My Amplifier* -VTune Amplifier Results -libxsmm*-* -libxsmm*_* -opentuner.db -bin/libxsmm_generator -include/libxsmm_version.h -include/libxsmm.f -lib/libxsmm* -lib/module -ide/GPUCache -ide/_vs*-*.bat -ide/.vs -ide/obj -ide/r*ah -samples2 -samples/*/bin -samples/*/*.sln -samples/*/*.dat -samples/*/*.pdf -samples/*/*.png -inspector* -licenses -bazel-* -python* -html -site -bin -tmp -obj -.couscous -.vscode -.state -.make -.vs -threadsafety-*.txt -malloc-trace-*.txt -blas-trace-*.txt -codecov-*.txt -keywords.txt -notes.txt -err*.txt -out*.txt -log.txt -_*.txt -.env.sh -.env_?????? -.tool_??????.sh -.libxsmm_??????.* -*.lastcodeanalysissucceeded -*.amplxeproj -*.advixeproj -*.inspxeproj -*.stackdump -*.opensdf -*.opendb -*.VC.db -*.dylib -*.sarif -*.docx -*.user -*.tlog -*.gcno -*.gcda -*.gcov -*.html -*.iobj -*.ipdb -*.URL -*.log -*.suo -*.exe -*.zip -*.pyc -*.sdf -*.ilk -*.pdb -*.vsp -*.obj -*.lib -*.mod -*.bin -*.jit -*.smm -*.soa -*.csr -*.dll -*.mhd -*.out -*.err -*.so -*.o -*.a -*.i -*.s -*.*~ diff --git a/third_party/libxsmm/.state b/third_party/libxsmm/.state new file mode 100644 index 0000000000000000000000000000000000000000..720e0337bd2cd0e0409c110e090550c359147405 --- /dev/null +++ b/third_party/libxsmm/.state @@ -0,0 +1,68 @@ +"ABSDIR=/public$HOME/dgl/third_party/libxsmm\n" +"ABSLIBS=0\n" +"ALPHA=1\n" +"AR=/usr/bin/gcc-ar\n" +"ASIMD=0\n" +"ASNEEDED=0\n" +"AUTOPIN=0\n" +"BETA=1\n" +"BLAS_CLDFLAGS=-lm\n" +"CACHE=1\n" +"CACHELINE=64\n" +"CC=gcc\n" +"CC_NAME=gcc\n" +"CC_VERSION=8.5.0\n" +"CFLAGS=-fPIC -Wall -O2 -fopenmp-simd -funroll-loops -ftree-vectorize -fdata-sections -ffunction-sections -fvisibility=hidden -pthread\n" +"COMMAND=/usr/bin/command\n" +"COMPATIBLE=0\n" +"CPUFLAGS_X86=fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid amd_dcm aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 xsaves clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sme sev sev_es\n" +"CTARGET=-msse4.2\n" +"CXXFLAGS=-fPIC -std=c++14 -Wall -O2 -fopenmp-simd -funroll-loops -ftree-vectorize -fdata-sections -ffunction-sections -fvisibility=hidden -fvisibility-inlines-hidden -pthread\n" +"CXXLDFLAGS=-lc\n" +"CXX_NAME=g++\n" +"CXX_VERSION=8.5.0\n" +"DBG=0\n" +"FAT=0\n" +"FCFLAGS=-fPIC -O2 -ftree-vectorize -fdata-sections -ffunction-sections\n" +"FLD=gcc\n" +"FLUSH=stdbuf -o0 -e0\n" +"FORTRAN=0\n" +"GCC_VERSION=8.5.0\n" +"GLIBC=1\n" +"ILP64=0\n" +"INSTRUMENT=0\n" +"INTRINSICS=1006\n" +"IPO=0\n" +"JITDUMP=0\n" +"LD=gcc\n" +"LDFLAGS=-Wl,--gc-sections -Wl,-z,relro,-z,now -lm -lrt -ldl\n" +"LIBATOMIC=0\n" +"LIBC=-lc\n" +"LNKSOFT=1\n" +"MAINTAINER=0\n" +"MALLOC=0\n" +"MIC=0\n" +"MKL=0\n" +"MNAME=x86_64\n" +"OFFLOAD=0\n" +"OMP=0\n" +"OMPFLAG=-fopenmp\n" +"OMPLIB=-L/usr/lib/gcc/x86_64-redhat-linux/8/ -lgomp\n" +"OMPRT=gomp\n" +"PERF=0\n" +"PLATFORM=0\n" +"PREFETCH=1\n" +"SONAMELNK=2\n" +"SPACES=0\n" +"STATIC=1\n" +"SYM=0\n" +"THREADS=1\n" +"THRESHOLD=0\n" +"TRACE=0\n" +"UNAME=Linux\n" +"VISIBILITY=0\n" +"WCHECK=0\n" +"WERROR_CFLAG=-Werror\n" +"WRAP=1\n" +"XLD=g++\n" +"\n" diff --git a/third_party/libxsmm/.theme/main.html b/third_party/libxsmm/.theme/main.html new file mode 100644 index 0000000000000000000000000000000000000000..bc42330bb9eab5ca884b5ce30b74fa90440c2957 --- /dev/null +++ b/third_party/libxsmm/.theme/main.html @@ -0,0 +1,14 @@ +{% extends "base.html" %} + +{% block site_meta %} + + + {{ super() }} +{% endblock %} + +{% block footer %} +
+ {%- if config.copyright %} +

{{ config.copyright }}

+ {%- endif %} +{% endblock %} diff --git a/third_party/libxsmm/bin/.make b/third_party/libxsmm/bin/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/bin/libxsmm_gemm_generator b/third_party/libxsmm/bin/libxsmm_gemm_generator new file mode 100755 index 0000000000000000000000000000000000000000..b69afe7e5ca42e5ced46197f316127e1df29c569 Binary files /dev/null and b/third_party/libxsmm/bin/libxsmm_gemm_generator differ diff --git a/third_party/libxsmm/documentation/libxsmm-dev.pptm b/third_party/libxsmm/documentation/libxsmm-dev.pptm new file mode 100644 index 0000000000000000000000000000000000000000..2251456058af800dd2a3a9a953f3d57ac0ed6fc5 Binary files /dev/null and b/third_party/libxsmm/documentation/libxsmm-dev.pptm differ diff --git a/third_party/libxsmm/documentation/libxsmm_aux.md b/third_party/libxsmm/documentation/libxsmm_aux.md new file mode 100644 index 0000000000000000000000000000000000000000..4626cfba18c929f03094debcd8617e63dbcf60ec --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_aux.md @@ -0,0 +1,203 @@ +## Service Functions + +### Target Architecture + +This functionality is available for the C and Fortran interface. There are [ID based](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_cpuid.h#L47) (same for C and Fortran) and string based functions to query the code path (as determined by the CPUID), or to set the code path regardless of the presented CPUID features. The latter may degrade performance if a lower set of instruction set extensions is requested, which can be still useful for studying the performance impact of different instruction set extensions. +**Note**: There is no additional check performed if an unsupported instruction set extension is requested, and incompatible JIT-generated code may be executed (unknown instruction signaled). + +```C +int libxsmm_get_target_archid(void); +void libxsmm_set_target_archid(int id); + +const char* libxsmm_get_target_arch(void); +void libxsmm_set_target_arch(const char* arch); +``` + +Available code paths (IDs and corresponding strings): + +* LIBXSMM_TARGET_ARCH_GENERIC: "**generic**", "none", "0" +* LIBXSMM_X86_GENERIC: "**x86**", "x64", "sse2" +* LIBXSMM_X86_SSE3: "**sse3**" +* LIBXSMM_X86_SSE42: "**wsm**", "nhm", "sse4", "sse4_2", "sse4.2" +* LIBXSMM_X86_AVX: "**snb**", "avx" +* LIBXSMM_X86_AVX2: "**hsw**", "avx2" +* LIBXSMM_X86_AVX512_MIC: "**knl**", "mic" +* LIBXSMM_X86_AVX512_KNM: "**knm**" +* LIBXSMM_X86_AVX512_CORE: "**skx**", "skl", "avx3", "avx512" +* LIBXSMM_X86_AVX512_CLX: "**clx**" +* LIBXSMM_X86_AVX512_CPX: "**cpx**" +* LIBXSMM_X86_AVX512_SPR: "**spr**" + +The **bold** names are returned by `libxsmm_get_target_arch` whereas `libxsmm_set_target_arch` accepts all of the above strings (similar to the environment variable LIBXSMM_TARGET). + +### Verbosity Level + +The [verbose mode](index.md#verbose-mode) (level of verbosity) can be controlled using the C or Fortran API, and there is an environment variable which corresponds to `libxsmm_set_verbosity` (LIBXSMM_VERBOSE). + +```C +int libxsmm_get_verbosity(void); +void libxsmm_set_verbosity(int level); +``` + +### Timer Facility + +Due to the performance oriented nature of LIBXSMM, timer-related functionality is available for the C and Fortran interface ([libxsmm_timer.h](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_timer.h#L37) and [libxsmm.f](https://github.com/hfp/libxsmm/blob/master/include/libxsmm.f#L32)). The timer is used in many of the [code samples](https://github.com/hfp/libxsmm/tree/master/samples) to measure the duration of executing a region of the code. The timer is based on a monotonic clock tick, which uses a platform-specific resolution. The counter may rely on the time stamp counter instruction (RDTSC), which is not necessarily counting CPU cycles (reasons are out of scope in this context). However, `libxsmm_timer_ncycles` delivers raw clock ticks (RDTSC). + +```C +typedef unsigned long long libxsmm_timer_tickint; +libxsmm_timer_tickint libxsmm_timer_tick(void); +double libxsmm_timer_duration( + libxsmm_timer_tickint tick0, + libxsmm_timer_tickint tick1); +libxsmm_timer_tickint libxsmm_timer_ncycles( + libxsmm_timer_tickint tick0, + libxsmm_timer_tickint tick1); +``` + +### User-Data Dispatch + +To register a user-defined key-value pair with LIBXSMM's fast key-value store, the key must be binary reproducible. Structured key-data (`struct` or `class` type which can be padded in a compiler-specific fashion) must be completely cleared, i.e., all gaps may be zero-filled before initializing data members (`memset(&mykey, 0, sizeof(mykey))`). This is because some compilers can leave padded data uninitialized, which breaks binary reproducible keys, hence the flow is: claring heterogeneous keys (struct), initialization (members), and registration. The size of the key is arbitrary but limited to LIBXSMM_DESCRIPTOR_MAXSIZE (96 Byte), and the size of the value can be of an arbitrary size. The given value is copied and may be initialized at registration-time or when dispatched. Registered data is released at program termination but can be manually unregistered and released (`libxsmm_xrelease`), e.g., to register a larger value for an existing key. + +```C +void* libxsmm_xregister(const void* key, size_t key_size, size_t value_size, const void* value_init); +void* libxsmm_xdispatch(const void* key, size_t key_size); +``` + +The Fortran interface is designed to follow the same flow as the C language: (1) `libxsmm_xdispatch` is used to query the value, and (2) if the value is a NULL-pointer, it is registered per `libxsmm_xregister`. Similar to C (`memset`), structured key-data must be zero-filled (`libxsmm_xclear`) even when followed by an element-wise initialization. A key based on a contiguous array has no gaps by definition and it is enough to initialize the array elements. A [Fortran example](https://github.com/hfp/libxsmm/blob/master/samples/utilities/dispatch/dispatch_udt.f) is given as part of the [Dispatch Microbenchmark](https://github.com/hfp/libxsmm/tree/master/samples/utilities/dispatch). + +```Fortran +FUNCTION libxsmm_xregister(key, keysize, valsize, valinit) + TYPE(C_PTR), INTENT(IN), VALUE :: key + TYPE(C_PTR), INTENT(IN), VALUE, OPTIONAL :: valinit + INTEGER(C_INT), INTENT(IN) :: keysize, valsize + TYPE(C_PTR) :: libxsmm_xregister +END FUNCTION + +FUNCTION libxsmm_xdispatch(key, keysize) + TYPE(C_PTR), INTENT(IN), VALUE :: key + INTEGER(C_INT), INTENT(IN) :: keysize + TYPE(C_PTR) :: libxsmm_xdispatch +END FUNCTION +``` + +**Note**: This functionality can be used to, e.g., dispatch multiple kernels in one step if a code location relies on multiple kernels. This way, one can pay the cost of dispatch one time per task rather than according to the number of JIT-kernels used by this task. However, the functionality is not limited to multiple kernels but any data can be registered and queried. User-data dispatch uses the same implementation as regular code-dispatch. + +### Memory Allocation + +The C interface ([libxsmm_malloc.h](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_malloc.h)) provides functions for aligned memory one of which allows to specify the alignment (or to request an automatically selected alignment). The automatic alignment is also available with a `malloc` compatible signature. The size of the automatic alignment depends on a heuristic, which uses the size of the requested buffer. +**Note**: The function `libxsmm_free` must be used to deallocate buffers allocated by LIBXSMM's allocation functions. + +```C +void* libxsmm_malloc(size_t size); +void* libxsmm_aligned_malloc(size_t size, size_t alignment); +void* libxsmm_aligned_scratch(size_t size, size_t alignment); +void libxsmm_free(const volatile void* memory); +int libxsmm_get_malloc_info(const void* m, libxsmm_malloc_info* i); +int libxsmm_get_scratch_info(libxsmm_scratch_info* info); +``` + +The library exposes two memory allocation domains: (1) default memory allocation, and (2) scratch memory allocation. There are similar service functions for both domains that allow to customize the allocation and deallocation function. The "context form" even supports a user-defined "object", which may represent an allocator or any other external facility. To set the allocator of the default domain is analogous to setting the allocator of the scratch memory domain (shown below). + +```C +int libxsmm_set_scratch_allocator(void* context, + libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn); +int libxsmm_get_scratch_allocator(void** context, + libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn); +``` + +The scratch memory allocation is very effective and delivers a decent speedup over subsequent regular memory allocations. In contrast to the default allocator, a watermark for repeatedly allocated and deallocated buffers is established. The scratch memory domain is (arbitrarily) limited to 4 GB of memory which can be adjusted to a different number of Bytes (available per [libxsmm_malloc.h](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_malloc.h), and also per environment variable LIBXSMM_SCRATCH_LIMIT with optional "k|K", "m|M", "g|G" units, unlimited per "-1"). + +```C +void libxsmm_set_scratch_limit(size_t nbytes); +size_t libxsmm_get_scratch_limit(void); +``` + +By establishing a pool of "temporary" memory, the cost of repeated allocation and deallocation cycles is avoided when the watermark is reached. The scratch memory is scope-oriented with a limited number of pools for buffers of different life-time or held for different threads. The [verbose mode](index.md#verbose-mode) with a verbosity level of at least two (LIBXSMM_VERBOSE=2) shows some statistics about the populated scratch memory. + +```bash +Scratch: 173 MB (mallocs=5, pools=1) +``` + +To improve thread-scalability and to avoid frequent memory allocation/deallocation, the scratch memory allocator can be leveraged by [intercepting existing malloc/free calls](libxsmm_tune.md#intercepted-allocations). + +**Note**: be careful with scratch memory as it only grows during execution (in between `libxsmm_init` and `libxsmm_finalize` unless `libxsmm_release_scratch` is called). This is true even when `libxsmm_free` is (and should be) used! + +### Meta Image File I/O + +Loading and storing data (I/O) is normally out of LIBXSMM's scope. However, comparing results (correctness) or writing files for visual inspection is clearly desired. This is particularly useful for the DNN domain. The MHD library domain provides support for the Meta Image File format (MHD). Tools such as [ITK-SNAP](http://itksnap.org/) or [ParaView](https://www.paraview.org/) can be used to inspect, compare, and modify images (even beyond two-dimensional images). + +Writing an image is per `libxsmm_mhd_write`, and loading an image is split in two stages: (1) `libxsmm_mhd_read_header`, and (2) `libxsmm_mhd_read`. The first step allows to allocate a properly sized buffer, which is then used to obtain the data per `libxsmm_mhd_read`. When reading data, an on-the-fly type conversion is supported. Further, data that is already in memory can be compared against file-data without allocating memory or reading this file into memory. + +To load an image from a familiar format (JPG, PNG, etc.), one may save the raw data using for instance [IrfanView](http://www.irfanview.com/) and rely on a "header-only" MHD-file (plain text). This may look like: + +```ini +NDims = 2 +DimSize = 202 134 +ElementType = MET_UCHAR +ElementNumberOfChannels = 1 +ElementDataFile = mhd_image.raw +``` + +In the above case, a single channel (gray-scale) 202x134-image is described with pixel data stored separately (`mhd_image.raw`). Multi-channel images are expected to interleave the pixel data. The pixel type is per `libxsmm_mhd_elemtype` ([libxsmm_mhd.h](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_mhd.h#L38)). + +### Thread Synchronization + +LIBXSMM comes with a number of light-weight abstraction layers (macro and API-based), which are distinct from the internal API (include files in [src](https://github.com/hfp/libxsmm/tree/master/src) directory) and that are exposed for general use (and hence part of the [include](https://github.com/hfp/libxsmm/tree/master/include) directory). + +The synchronization layer is mainly based on macros: LIBXSMM_LOCK_\* provide spin-locks, mutexes, and reader-writer locks (LIBXSMM_LOCK_SPINLOCK, LIBXSMM_LOCK_MUTEX, and LIBXSMM_LOCK_RWLOCK respectively). Usually the spin-lock is also named LIBXSMM_LOCK_DEFAULT. The implementation is intentionally based on OS-native primitives unless LIBXSMM is reconfigured (per LIBXSMM_LOCK_SYSTEM) or built using `make OMP=1` (using OpenMP inside of the library is not recommended). The life-cycle of a lock looks like: + +```C +/* attribute variable and lock variable */ +LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_DEFAULT) attr; +LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_DEFAULT) lock; +/* attribute initialization */ +LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_DEFAULT, &attr); +/* lock initialization per initialized attribute */ +LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_DEFAULT, &lock, &attr); +/* the attribute can be destroyed */ +LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_DEFAULT, &attr); +/* lock destruction (usage: see below/next code block) */ +LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_DEFAULT, &lock); +``` + +Once the lock is initialized (or an array of locks), it can be exclusively locked or try-locked, and released at the end of the locked section (LIBXSMM_LOCK_ACQUIRE, LIBXSMM_LOCK_TRYLOCK, and LIBXSMM_LOCK_RELEASE respectively): + +```C +LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_DEFAULT, &lock); +/* locked code section */ +LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &lock); +``` + +If the lock-kind is LIBXSMM_LOCK_RWLOCK, non-exclusive a.k.a. shared locking allows to permit multiple readers (LIBXSMM_LOCK_ACQREAD, LIBXSMM_LOCK_TRYREAD, and LIBXSMM_LOCK_RELREAD) if the lock is not acquired exclusively (see above). An attempt to only read-lock anything else but an RW-lock is an exclusive lock (see above). + +```C +if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK) == + LIBXSMM_LOCK_TRYREAD(LIBXSMM_LOCK_RWLOCK, &rwlock)) +{ /* locked code section */ + LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &rwlock); +} +``` + +Locking different sections for read (LIBXSMM_LOCK_ACQREAD, LIBXSMM_LOCK_RELREAD) and write (LIBXSMM_LOCK_ACQUIRE, LIBXSMM_LOCK_RELEASE) may look like: + +```C +LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &rwlock); +/* locked code section: only reads are performed */ +LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &rwlock); + +LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &rwlock); +/* locked code section: exclusive write (no R/W) */ +LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &rwlock); +``` + +For a lock not backed by an OS level primitive (fully featured lock), the synchronization layer also provides a simple lock based on atomic operations: + +```C +static union { char pad[LIBXSMM_CACHELINE]; volatile LIBXSMM_ATOMIC_LOCKTYPE state; } lock; +LIBXSMM_ATOMIC_ACQUIRE(&lock.state, LIBXSMM_SYNC_NPAUSE, LIBXSMM_ATOMIC_RELAXED); +/* locked code section */ +LIBXSMM_ATOMIC_RELEASE(&lock.state, LIBXSMM_ATOMIC_RELAXED); +``` + +In addition to the LIBXSMM_LOCK_\* macros or LIBXSMM_ATOMIC_LOCKTYPE, API-based lock primitives are also available (libxsmm_mutex_\*, and libxsmm_rwlock_\*). However, the underlying implementation of the latter is experimental. + diff --git a/third_party/libxsmm/documentation/libxsmm_be.md b/third_party/libxsmm/documentation/libxsmm_be.md new file mode 100644 index 0000000000000000000000000000000000000000..f84fe381de3eb8eb1e80aec4973b87b56ce2ad4b --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_be.md @@ -0,0 +1,76 @@ +## Backend + +### Code Generator (JIT) + +There can be situations in which it is up-front not clear which problem-sizes will be needed when running an application. To leverage LIBXSMM's high-performance kernels, the library implements a JIT (Just-In-Time) code generation backend which generates the requested kernels on the fly (in-memory). This is accomplished by emitting the corresponding byte-code directly into an executable buffer. The actual JIT code is generated per the CPUID flags, and therefore does not rely on the code path selected when building the library. In the current implementation, some limitations apply to the JIT backend specifically: + +1. To stay agnostic to any threading model used, Pthread mutexes are guarding the updates of the JIT'ted code cache (link line with `-lpthread` is required); building with OMP=1 employs an OpenMP critical section as an alternative locking mechanism. +2. There is limited support for the Windows calling convention (only kernels without prefetch signature). + +The JIT backend can also be disabled at build time (`make JIT=0`) as well as at runtime (`LIBXSMM_TARGET=0`, or anything prior to Intel AVX). The latter is an environment variable which allows to set a code path independent of the CPUID (LIBXSMM_TARGET=0|1|sse|snb|hsw|knl|knm|skx|clx|cpx|spr). Please note that LIBXSMM_TARGET cannot enable the JIT backend if it was disabled at build time (JIT=0). + +One can use the afore mentioned THRESHOLD parameter to control the matrix sizes for which the JIT compilation will be automatically performed. However, explicitly requested kernels (by calling `libxsmm_?mmdispatch`) fall not under a threshold for the problem-size. In any case, JIT code generation can be used for accompanying statically generated code. + +### Generator Driver + +In rare situations, it might be useful to directly incorporate generated C code (with inline assembly regions). This is accomplished by invoking a driver program (with certain command line arguments). + +**Note**: The stand-alone generator-driver is considered legacy (deprecated). Associated functionality may be removed and future instruction set extensions may not be addressed with printed assembly code. The cost of dispatching JIT-code for every code region of an application, and for every visit of such region, can be amortized in several ways and without dispensing JIT-generated code. Dispatching [multiple kernels at once](libxsmm_aux.md#user-data-dispatch) or (most effectively) tabulating JIT'ted function pointers manually, can elleviate or remove first-time code generation and (more important) the cost of subsequently dispatching kernels (when code was already JIT-generated). + +The generator driver program is usually built as part of LIBXSMM's build process, but also available as a separate build target: + +```bash +make generator +bin/libxsmm_gemm_generator +``` + +The code generator driver program accepts the following arguments: + +1. Select: dense, dense_asm, sparse, sparse_csr, or sparse_csr_reg +2. Filename of a file to append to +3. Routine name to be created +4. M parameter +5. N parameter +6. K parameter +7. LDA (0 indicates A is sparse if 1st arg. is "sparse*") +8. LDB (0 indicates B is sparse if 1st arg. is "sparse*") +9. LDC parameter +10. Alpha (1) +11. Beta: (0 or 1) +12. Alignment override for A (1 auto, 0 unalignment) +13. Alignment override for C (1 auto, 0 unalignment) +14. Architecture (noarch, wsm, snb, hsw, knl, knm, skx, clx, cpx) +15. Prefetch strategy, see below (only nopf or pfsigonly for "sparse*") +16. SP (single-precision), DP (double-recision), or I16 (only "dense*") +17. CSC file in Matrix market format (only if 1st arg. is "sparse*"). + +The prefetch strategy can be: + +1. "nopf": data is not prefetched, just three arguments: A, B, and C +2. "pfsigonly": no prefetches, kernel signature: A, B, C, A', B', and C' +3. "BL2viaC": uses accesses to C to prefetch B' +4. "AL2": uses accesses to A to prefetch A +5. "curAL2": prefetches current A ahead in the kernel +6. "AL2_BL2viaC": combines AL2 and BL2viaC +7. "curAL2_BL2viaC": combines curAL2 and BL2viaC + +Here are some examples of invoking the driver program: + +```bash +bin/libxsmm_gemm_generator dense foo.c foo 16 16 16 32 32 32 1 1 1 1 hsw nopf DP +bin/libxsmm_gemm_generator dense_asm foo.c foo 16 16 16 32 32 32 1 1 1 1 knl AL2_BL2viaC DP +bin/libxsmm_gemm_generator sparse foo.c foo 16 16 16 32 0 32 1 1 1 1 hsw nopf DP bar.csc +``` + +Please note, there are additional examples given in samples/generator and samples/seissol. + +### Development Concepts + +The low-level code generator is hosted by a single translation unit ([src/generator_x86_instructions.c](https://github.com/hfp/libxsmm/blob/master/src/generator_x86_instructions.h)). The code generator emits instructions as enumerated in [src/generator_common.h](https://github.com/hfp/libxsmm/blob/master/src/generator_common.h). A kernel then is a buffered stream of instructions in either binary/encoded or textual form. The latter is leveraged by stand-alone generator drivers that can print C functions with an assembly section (inline). A [generator driver](#generator-driver) may exists for some of LIBXSMM's function domains. Please note that emitting the textual form is not needed to inspect the emitted code since the binary encoded form can be easily disassembled ([objdump](index.md#objdump)). + +The binary encoded form is directly suitable for execution by casting the code-buffer into a function-pointer of the corresponding signature. It is advised to rely on LIBXSMM's internal memory allocation routines to acquire an executable buffer (see libxsmm_malloc_flags, libxsmm_xmalloc, and libxsmm_malloc_attrib in [src/libxsmm_main.h](https://github.com/hfp/libxsmm/blob/master/src/libxsmm_main.h)). This ensures correct behavior in security-hardened environments. As a bonus, [profiler support](libxsmm_prof.md) for the emitted code is enabled transparently. + +To debug the JIT'ted code, GNU GDB can be used to disassemble a given memory address (`disas address,+length`). Having the code disassembled side-by-side (while debugging) helps to look ahead and to have some orientation. For the latter, [objdump](index.md#objdump) can be used to acquire the source code (assembly) along with hexadecimal line numbers (length). The offset position (for GDB's disas) directly corresponds to objectdump's line numbers. + +The kernel development is much like assembly programming, except that an API is used to emit instructions. For further reference, some existing source code for building kernels can be inspected (e.g., matcopy). This may help to capture the concept of mapping registers (basically a table to avoid hard-coding register names). + diff --git a/third_party/libxsmm/documentation/libxsmm_compat.md b/third_party/libxsmm/documentation/libxsmm_compat.md new file mode 100644 index 0000000000000000000000000000000000000000..e8274b01ef265ce67538436d1cb929a3f61e946c --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_compat.md @@ -0,0 +1,97 @@ +## Linux + +All Linux distributions are meant to be fully supported (please [report](https://github.com/hfp/libxsmm/issues/new) any compatibility issue). A shared library (`STATIC=0`) necessarily implies some performance hit when accessing thread-local memory (contended multicore execution). The GNU Compiler Collection prior to v5.1 may imply performance hits in some CPUID-dispatched code paths (non-JIT). + +> In case of outdated Binutils, compilation can fail to assemble code that originates from code sections using Intrinsics (see issue [#170](https://github.com/hfp/libxsmm/issues/170) and [#212](https://github.com/hfp/libxsmm/issues/212#issuecomment-394620082)). To resolve the problem, please use `INTRINSICS=1` along with the desired target e.g., `AVX=3 MIC=0`, or `AVX=2`. + +## CRAY + +In addition to the regular Linux support, The CRAY Compiling Environment (CCE) is supported: Intel Compiler as well as the GNU Compiler Collection are detected even when invoked per CCE, and the CRAY compiler is likely configured to build for the architecture of the compute nodes and hence the compiler is sufficiently treated without specific build flags (`COMPATIBLE=1` is implicitly set). The CCE may suppress to build a shared library (`STATIC=0`), which also affects the TRACE facility (requires dynamic linkage even for static archives). + +```bash +make CXX=CC CC=cc FC=ftn +``` + +The compatibility settings imply minor issues when using the CRAY compiler: full control and [customization](http://libxsmm.readthedocs.io/libxsmm_tune/) is not implemented, enabling symbols (`SYM=1`) appears to imply an unoptimized debug-build (due to the `-g` flag being present). Some sample codes/benchmarks enable symbols but are meant to not enable debug-code. The LIBXSMM library however is built without symbols by default. + +## Windows + +### Microsoft Windows + +Microsoft Windows is [supported](https://github.com/hfp/libxsmm/wiki/Q&A#what-operating-systems-are-covered-by-libxsmm-and-what-about-microsoft-windows) using the Microsoft Visual Studio environment (no `make`). It is advised to review the build settings. However, the following configurations are available: `debug`, `release`, and release mode with `symbols`. JIT-code generation is enabled but limited to the MM domain (GEMM kernels and matcopy kernels; no transpose kernels). GEMM kernels with prefetch signature remain as non-prefetch kernels i.e., prefetch locations are ignored due to the effort of fully supporting the Windows calling convention. As a workaround and to properly preserve caller-state, each JIT-kernel call may be wrapped by an own function. + +### Cygwin + +Cygwin (non-MinGW) is fully supported. Please note, that all limitations of Microsoft Windows apply. + +```bash +make +``` + +LIBXSMM can be built as a static library as well as a dynamic link library (STATIC=0). + +### MinGW/Cygwin + +This is about the Cygwin-hosted bits of MinGW. The `-fno-asynchronous-unwind-tables` compiler flag is automatically applied. Please note, that all limitations of Microsoft Windows apply. + +```bash +make \ + CXX=x86_64-w64-mingw32-g++ \ + CC=x86_64-w64-mingw32-gcc \ + FC=x86_64-w64-mingw32-gfortran +``` + +To run tests, `BLAS=0` may be supplied (since Cygwin does not seem to provide BLAS-bits for the MinGW part). However, this may be different for "native" MinGW, or can be fixed by supplying a BLAS library somehow else. + +### MinGW + +This is about the "native" MinGW environment. Please note, there is the original [MinGW](https://mingw.osdn.io/) as well as a [fork](http://mingw-w64.org/) (made in 2007). Both of which can target Windows 64-bit. Here, the [MSYS2 installer](https://www.msys2.org/) (scroll down on that page to see the full installation instructions) has been used (see the [details](https://github.com/msys2/msys2/wiki/MSYS2-installation) on how to install missing packages). + +```bash +pacman -S msys/make msys/python msys/diffutils \ + mingw64/mingw-w64-x86_64-gcc mingw64/mingw-w64-x86_64-gcc-fortran \ + mingw64/mingw-w64-x86_64-openblas +``` + +Similar to Cygwin/MinGW, the `-fno-asynchronous-unwind-tables` flag is automatically applied. + +```bash +make +``` + +LIBXSMM can be built as a static library as well as a dynamic link library (`STATIC=0`). + +## Apple macOS + +LIBXSMM for macOS (OSX) is fully supported (i.e., it qualifies a release). The default is to rely on Apple's Clang based (platform-)compiler ("gcc"). However, the actual GCC as well as the Intel Compiler for macOS can be used. + +## FreeBSD + +LIBXSMM is occasionally tested under FreeBSD. For libxsmmext, it is necessary to install OpenMP (`sudo pkg install openmp`). + +```bash +bash +gmake +``` +An attempt to run the [tests](https://github.com/hfp/libxsmm/wiki/Validation) may ask for a LAPACK/BLAS installation (unless `BLAS=0` is given). Both, Netlib BLAS (reference) and OpenBLAS are available (in case of linker error due to the GNU Fortran runtime library, one can try `gmake CXX=g++7 CC=gcc7 FC=gfortran7` i.e., select a consistent tool chain and adjust `LD_LIBRARY_PATH` accordingly e.g., `/usr/local/lib/gcc7`). + +## PGI Compiler + +The PGI Compiler 2019 (and later) is supported. Earlier versions were only occasionally tested and automatically enabled the `COMPATIBLE=1` and `INTRINSIC=0` settings. Still, atomic builtins seem incomplete (at least with `pgcc`) hence LIBXSMM built with PGI Compiler is not fully thread-safe (tests/threadsafety can fail). Support for GNU's libatomic has been incorporated mainly for PGI but is also missing built-in compiler support hence supposedly atomic operations are mapped to normal (non-atomic) code sequences (`LIBXSMM_SYNC_SYSTEM`). + +```bash +make CXX=pgc++ CC=pgcc FC=pgfortran +``` + +### ARM AArch64 + +This section is not strictly about compiler compatibility but rather about AArch64 (v8.1) being supported, which practically covers the baseline ARM 64-bit architecture from embedded and mobile to supercomputers. The build and installation process of LIBXSMM is the same as for Intel Architecture (IA) and the library can be natively compiled or cross-compiled. The latter for instance looks like: + +```bash +make PLATFORM=1 AR=aarch64-linux-gnu-ar \ + FC=aarch64-linux-gnu-gfortran \ + CXX=aarch64-linux-gnu-g++ \ + CC=aarch64-linux-gnu-gcc +``` + +**Note**: Apple M1 is supported but JIT code generation may fail due to macOS 11 ("Big Sur"). LIBXSMM does not currently support macOS 11.x (regardless of ARM or Intel Architecture). diff --git a/third_party/libxsmm/documentation/libxsmm_dl.md b/third_party/libxsmm/documentation/libxsmm_dl.md new file mode 100644 index 0000000000000000000000000000000000000000..c67b131703531cf8766a19d4dee09fae599f3c87 --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_dl.md @@ -0,0 +1,133 @@ +## Deep Neural Networks + +To achieve best performance with small convolutions for CNN on SIMD architectures, a specific data layout must be used. As this layout depends on several architectural parameters, the goal of LIBXSMM's interface is to hide this complexity from the user by providing copy-in and copy-out routines. This happens using opaque data types, which themselves are later bound to a convolution operation. + +The interface is available for C. There is a collection of code samples ([samples/deeplearning](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning)) available including a light-weight [framework for deep learning (GXM)](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning/gxm), and samples with focus on [Convolutional Deep Neural Networks (DNNs)](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning/cnnlayer), or [LSTM cells](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning/lstmdriver), etc. The general concept of the CNN interface is circled around a few types: `libxsmm_dnn_layer`, `libxsmm_dnn_buffer`, `libxsmm_dnn_bias`, and `libxsmm_dnn_filter`. A handle of such a type is always setup by calling a create-function. + +```C +/** Simplified LIBXSMM types which are needed to create a handle. */ + +/** Structure which describes the input and output of data (DNN). */ +typedef struct libxsmm_dnn_conv_desc { + int N; /* number of images in mini-batch */ + int C; /* number of input feature maps */ + int H; /* height of input image */ + int W; /* width of input image */ + int K; /* number of output feature maps */ + int R; /* height of filter kernel */ + int S; /* width of filter kernel */ + int u; /* vertical stride */ + int v; /* horizontal stride */ + int pad_h; /* height of logical rim padding to input + for adjusting output height */ + int pad_w; /* width of logical rim padding to input + for adjusting output width */ + int pad_h_in; /* height of zero-padding in input buffer, + must equal to pad_h for direct conv */ + int pad_w_in; /* width of zero-padding in input buffer, + must equal to pad_w for direct conv */ + int pad_h_out; /* height of zero-padding in output buffer */ + int pad_w_out; /* width of zero-padding in output buffer */ + int threads; /* number of threads to use when running + convolution */ + libxsmm_dnn_datatype datatype; /* datatypes use for all input and outputs */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for buffer buffers */ + libxsmm_dnn_tensor_format filter_format; /* format which is for filter buffers */ + libxsmm_dnn_conv_algo algo; /* convolution algorithm used */ + libxsmm_dnn_conv_option options; /* additional options */ + libxsmm_dnn_conv_fuse_op fuse_ops; /* used ops into convolutions */ +} libxsmm_dnn_conv_desc; + +/** Type of algorithm used for convolutions. */ +typedef enum libxsmm_dnn_conv_algo { + /** let the library decide */ + LIBXSMM_DNN_CONV_ALGO_AUTO, /* ignored for now */ + /** direct convolution. */ + LIBXSMM_DNN_CONV_ALGO_DIRECT +} libxsmm_dnn_conv_algo; + +/** Denotes the element/pixel type of an image/channel. */ +typedef enum libxsmm_dnn_datatype { + LIBXSMM_DNN_DATATYPE_F32, + LIBXSMM_DNN_DATATYPE_I32, + LIBXSMM_DNN_DATATYPE_I16, + LIBXSMM_DNN_DATATYPE_I8 +} libxsmm_dnn_datatype; + +libxsmm_dnn_layer* libxsmm_dnn_create_conv_layer( + libxsmm_dnn_conv_desc conv_desc, libxsmm_dnn_err_t* status); +libxsmm_dnn_err_t libxsmm_dnn_destroy_conv_layer( + const libxsmm_dnn_layer* handle); +``` + +A sample call looks like (without error checks): + +```C +/* declare LIBXSMM variables */ +libxsmm_dnn_conv_desc conv_desc; +libxsmm_dnn_err_t status; +libxsmm_dnn_layer* handle; +/* setting conv_desc values.... */ +conv_desc.N = ... +/* create handle */ +handle = libxsmm_dnn_create_conv_layer(conv_desc, &status); +``` + +Next activation and filter buffers need to be linked, initialized and bound to the handle. Afterwards the convolution can be executed in a threading environment of choice (error checks are omitted for brevity): + +```C +float *input, *output, *filter; +libxsmm_dnn_buffer* libxsmm_reg_input; +libxsmm_dnn_buffer* libxsmm_reg_output; +libxsmm_dnn_filter* libxsmm_reg_filter; + +/* allocate data */ +input = (float*)libxsmm_aligned_malloc(...); +output = ...; + +/* link data to buffers */ +libxsmm_reg_input = libxsmm_dnn_link_buffer( libxsmm_handle, LIBXSMM_DNN_INPUT, input, + LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); +libxsmm_reg_output = libxsmm_dnn_link_buffer( libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, + LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); +libxsmm_reg_filter = libxsmm_dnn_link_filter( libxsmm_handle, LIBXSMM_DNN_FILTER, filter, + LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); + +/* copy in data to LIBXSMM format: naive format is: */ +/* (mini-batch)(number-featuremaps)(featuremap-height)(featuremap-width) for layers, */ +/* and the naive format for filters is: */ +/* (number-output-featuremaps)(number-input-featuremaps)(kernel-height)(kernel-width) */ +libxsmm_dnn_copyin_buffer(libxsmm_reg_input, (void*)naive_input, LIBXSMM_DNN_TENSOR_FORMAT_NCHW); +libxsmm_dnn_zero_buffer(libxsmm_reg_output); +libxsmm_dnn_copyin_filter(libxsmm_reg_filter, (void*)naive_filter, LIBXSMM_DNN_TENSOR_FORMAT_KCRS); + +/* bind layer to handle */ +libxsmm_dnn_bind_input_buffer(libxsmm_handle, libxsmm_reg_input, LIBXSMM_DNN_REGULAR_INPUT); +libxsmm_dnn_bind_output_buffer(libxsmm_handle, libxsmm_reg_output, LIBXSMM_DNN_REGULAR_OUTPUT); +libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_reg_filter, LIBXSMM_DNN_REGULAR_FILTER); + +/* allocate and bind scratch */ +scratch = libxsmm_aligned_scratch(libxsmm_dnn_get_scratch_size( + libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status), 2097152); +libxsmm_dnn_bind_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, scratch); + +/* run the convolution */ +#pragma omp parallel +{ + libxsmm_dnn_convolve_st(libxsmm_handle, LIBXSMM_DNN_CONV_KIND_FWD, 0, + omp_get_thread_num(), omp_get_num_threads()); +} + +/* copy out data */ +libxsmm_dnn_copyout_buffer(libxsmm_output, (void*)naive_libxsmm_output, + LIBXSMM_DNN_TENSOR_FORMAT_NCHW); + +/* clean up */ +libxsmm_dnn_release_scratch(...); +libxsmm_dnn_release_buffer(...); +... +libxsmm_dnn_destroy_buffer(...); +... +libxsmm_dnn_destroy_conv_layer(...); +``` + diff --git a/third_party/libxsmm/documentation/libxsmm_fortran.md b/third_party/libxsmm/documentation/libxsmm_fortran.md new file mode 100644 index 0000000000000000000000000000000000000000..660f5ec7e531eea47636d54f779c26892237736e --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_fortran.md @@ -0,0 +1,14 @@ +Title: LIBXSMM +project: LIBXSMM +author: Intel Corporation +summary: Library targeting Intel Architecture for specialized matrix operations. +project_github: https://github.com/hfp/libxsmm +project_download: https://github.com/hfp/libxsmm/releases/latest +favicon: ../.theme/img/favicon.png +css: ../.theme/ford.css +output_dir: ../html +src_dir: ../include +search: true +page_dir: . + +Library targeting Intel Architecture for specialized matrix operations: [libxsmm.readthedocs.io/](https://libxsmm.readthedocs.io/) diff --git a/third_party/libxsmm/documentation/libxsmm_magazine.docx b/third_party/libxsmm/documentation/libxsmm_magazine.docx new file mode 100644 index 0000000000000000000000000000000000000000..22710a9150a18dbad6fb7cb064032ad807b0f666 Binary files /dev/null and b/third_party/libxsmm/documentation/libxsmm_magazine.docx differ diff --git a/third_party/libxsmm/documentation/libxsmm_mm.docx b/third_party/libxsmm/documentation/libxsmm_mm.docx new file mode 100644 index 0000000000000000000000000000000000000000..a28b32f2b1248a7a81c444b479dd901075dcccce Binary files /dev/null and b/third_party/libxsmm/documentation/libxsmm_mm.docx differ diff --git a/third_party/libxsmm/documentation/libxsmm_mm.md b/third_party/libxsmm/documentation/libxsmm_mm.md new file mode 100644 index 0000000000000000000000000000000000000000..7a69cd4fafa958aa4a549635f081c3599fe85418 --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_mm.md @@ -0,0 +1,238 @@ +## Matrix Multiplication + +### Overview + +To perform the dense matrix-matrix multiplication Cm x n = alpha · Am x k · Bk x n + beta · Cm x n, the full-blown GEMM interface can be treated with "default arguments" (which is deviating from the BLAS standard, however without compromising the binary compatibility). Default arguments are derived from compile-time constants (configurable) for historic reasons (LIBXSMM's "pre-JIT era"). + +```C +libxsmm_?gemm(NULL/*transa*/, NULL/*transb*/, + &m/*required*/, &n/*required*/, &k/*required*/, + NULL/*alpha*/, a/*required*/, NULL/*lda*/, + b/*required*/, NULL/*ldb*/, + NULL/*beta*/, c/*required*/, NULL/*ldc*/); +``` + +For the C interface (with type prefix `s` or `d`), all arguments including m, n, and k are passed by pointer. This is needed for binary compatibility with the original GEMM/BLAS interface. + +```C +libxsmm_gemm(NULL/*transa*/, NULL/*transb*/, + m/*required*/, n/*required*/, k/*required*/, + NULL/*alpha*/, a/*required*/, NULL/*lda*/, + b/*required*/, NULL/*ldb*/, + NULL/*beta*/, c/*required*/, NULL/*ldc*/); +``` + +The C++ interface is also supplying overloaded versions where m, n, and k can be passed by‑value (making it clearer that m, n, and k are non-optional arguments). + +```FORTRAN +! Dense matrix multiplication (single/double-precision). +CALL libxsmm_?gemm(m=m, n=n, k=k, a=a, b=b, c=c) +! Dense matrix multiplication (generic interface). +CALL libxsmm_gemm(m=m, n=n, k=k, a=a, b=b, c=c) +``` + +The FORTRAN interface supports optional arguments (without affecting the binary compatibility with the original BLAS interface) by allowing to omit arguments where the C/C++ interface allows for NULL to be passed. + +```C +/** Dense matrix multiplication (single/double-precision). */ +libxsmm_blas_?gemm(NULL/*transa*/, NULL/*transb*/, + &m/*required*/, &n/*required*/, &k/*required*/, + NULL/*alpha*/, a/*required*/, NULL/*lda*/, + b/*required*/, NULL/*ldb*/, + NULL/*beta*/, c/*required*/, NULL/*ldc*/); +``` + +For convenience, a BLAS-based dense matrix multiplication (`libxsmm_blas_gemm`) is provided for all supported languages. This only re-exposes the underlying GEMM/BLAS implementation, but the interface accepts optional arguments (or NULL pointers in C) where the regular GEMM expects a value. To remove any BLAS-dependency, please follow the [Link Instructions](index.md#link-instructions). A BLAS-based GEMM can be useful for validation/benchmark purposes, and more important as a fallback when building an application-specific dispatch mechanism. + +```C +/** OpenMP parallelized dense matrix multiplication. */ +libxsmm_?gemm_omp(&transa, &transb, &m, &n, &k, + &alpha, a, &lda, b, &ldb, &beta, c, &ldc); +``` + +A more recently added variant of matrix multiplication is parallelized based on the OpenMP standard. These routines will open an internal parallel region and rely on "classic" thread based OpenMP. If these routines are called from inside of a parallel region, the parallelism will be based on tasks (OpenMP 3.0). Please note that all OpenMP-based routines are hosted by the extension library (libxsmmext), which keeps the main library agnostic with respect to a threading runtime. + +### Manual Code Dispatch + +Successively calling a kernel (i.e., multiple times) allows for amortizing the cost of the code dispatch. Moreover, to customize the dispatch mechanism, one can rely on the following interface. + +```C +/** Call dispatched (*function_ptr)(a, b, c [, pa, pb, pc]). */ +libxsmm_[s|d]mmfunction libxsmm_[type-prefix]mmdispatch( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + /** NULL: tight fit (m) */ const libxsmm_blasint* lda, + /** NULL: tight fit (k) */ const libxsmm_blasint* ldb, + /** NULL: tight fit (m) */ const libxsmm_blasint* ldc, + /** NULL: LIBXSMM_ALPHA */ const type* alpha, + /** NULL: LIBXSMM_BETA */ const type* beta, + /** NULL: LIBXSMM_FLAGS */ const int* flags, + /** NULL: LIBXSMM_PREFETCH_NONE (not LIBXSMM_PREFETCH!) */ + const int* prefetch); +``` + +Overloaded function signatures are provided and allow to omit arguments (C++ and FORTRAN), which are then derived from the [configurable defaults](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_config.h). In C++, `libxsmm_mmfunction` can be used to instantiate a functor rather than making a distinction between numeric types per type-prefix. For lower precision GEMMs, `libxsmm_mmfunction` optionally takes a second type (output type). + +```C +/* generates or dispatches the code specialization */ +libxsmm_mmfunction xmm(m, n, k); +if (xmm) { /* JIT'ted code */ + /* can be parallelized per, e.g., OpenMP */ + for (int i = 0; i < n; ++i) { + xmm(a+i*asize, b+i*bsize, c+i*csize); + } +} +``` + +Similarly in FORTRAN (see [samples/smm/smm.f](https://github.com/hfp/libxsmm/blob/master/samples/smm/smm.f)), a generic interface (`libxsmm_mmdispatch`) can be used to dispatch a `LIBXSMM_?MMFUNCTION`. The handle encapsulated by such a `LIBXSMM_?MMFUNCTION` can be called per `libxsmm_call`. Beside of dispatching code, one can also call statically generated kernels (e.g., `libxsmm_dmm_4_4_4`) by using the prototype functions included with the FORTRAN and C/C++ interface. Prototypes are present whenever static code was requested at compile-time of the library (e.g. per `make MNK="1 2 3 4 5"`). + +```FORTRAN +TYPE(LIBXSMM_DMMFUNCTION) :: xmm +CALL libxsmm_dispatch(xmm, m, n, k) +IF (libxsmm_available(xmm)) THEN + DO i = LBOUND(c, 3), UBOUND(c, 3) ! consider OpenMP + CALL libxsmm_dmmcall(xmm, a(:,:,i), b(:,:,i), c(:,:,i)) + END DO +END IF +``` + +### Batched Multiplication + +In case of batched SMMs, it can be beneficial to supply "next locations" such that the upcoming operands are prefetched ahead of time. Such a location would be the address of the next matrix to be multiplied (and not any of the floating-point elements within the "current" matrix-operand). The "prefetch strategy" is requested at dispatch-time of a kernel. A [strategy](libxsmm_be.md#prefetch-strategy) other than `LIBXSMM_PREFETCH_NONE` turns the signature of a JIT'ted kernel into a function with six arguments (`a,b,c, pa,pb,pc` instead of `a,b,c`). To defer the decision about the strategy to a CPUID-based mechanism, one can choose `LIBXSMM_PREFETCH_AUTO`. + +```C +int prefetch = LIBXSMM_PREFETCH_AUTO; +int flags = 0; /* LIBXSMM_FLAGS */ +libxsmm_dmmfunction xmm = NULL; +double alpha = 1, beta = 0; +xmm = libxsmm_dmmdispatch(23/*m*/, 23/*n*/, 23/*k*/, + NULL/*lda*/, NULL/*ldb*/, NULL/*ldc*/, + &alpha, &beta, &flags, &prefetch); +``` + +Above, pointer-arguments of `libxsmm_dmmdispatch` can be NULL (or OPTIONAL in FORTRAN): for LDx this means a "tight" leading dimension, alpha, beta, and flags are given by a [default value](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_config.h) (which is selected at compile-time), and for the prefetch strategy a NULL-argument refers to "no prefetch" (which is equivalent to an explicit `LIBXSMM_PREFETCH_NONE`). By design, the prefetch strategy can be changed at runtime (as soon as valid next-locations are used) without changing the call-site (kernel-signature with six arguments). + + + +```C +if (0 < n) { /* check that n is at least 1 */ +# pragma parallel omp private(i) + for (i = 0; i < (n - 1); ++i) { + const double *const ai = a + i * asize; + const double *const bi = b + i * bsize; + double *const ci = c + i * csize; + xmm(ai, bi, ci, ai + asize, bi + bsize, ci + csize); + } + xmm(a + (n - 1) * asize, b + (n - 1) * bsize, c + (n - 1) * csize, + /* pseudo prefetch for last element of batch (avoids page fault) */ + a + (n - 1) * asize, b + (n - 1) * bsize, c + (n - 1) * csize); +} +``` + +To process a batch of matrix multiplications and to prefetch the operands of the next multiplication ahead of time, the code presented in the [Overview](#overview) section may be modified as shown above. The last multiplication is peeled from the main batch to avoid prefetching out-of-bounds (OOB). Prefetching from an invalid address does not trap an exception, but an (unnecessary) page fault can be avoided. + + + +```C +/** Batched matrix multiplications (explicit data representation). */ +int libxsmm_mmbatch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride_a[], + const libxsmm_blasint stride_b[], + const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize, + int tid, int ntasks); +``` + +To further simplify the multiplication of matrices in a batch, LIBXSMM's batch interface can help to extract the necessary input from a variety of existing structures (integer indexes, array of pointers both with Byte sized strides). An expert interface (see above) can employ a user-defined threading runtime (`tid` and `ntasks`). In case of OpenMP, `libxsmm_mmbatch_omp` is ready-to-use and hosted by the extension library (libxsmmext). Of course, `libxsmm_mmbatch_omp` does not take `tid` and `ntasks` since both arguments are given by OpenMP. Similarly, a sequential version (shown below) is available per `libxsmm_gemm_batch` (libxsmm). + +Please note that an explicit data representation should exist and reused rather than created only to call the explicit batch-interface. Creating such a data structure only for this matter can introduce an overhead which is hard to amortize (speedup). If no explicit data structure exists, a "chain" of multiplications can be often algorithmically described (see [self-hosted batch loop](#implicit-batches)). + +```C +void libxsmm_gemm_batch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride_a[], + const libxsmm_blasint stride_b[], + const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize); +``` + +In recent BLAS library implementations, `dgemm_batch` and `sgemm_batch` have been introduced. This BLAS(-like) interface allows for groups of homogeneous batches, which is like an additional loop around the interface as introduced above. On the other hand, the BLAS(-like) interface only supports arrays of pointers for the matrices. In contrast, above interface supports arrays of pointers as well as arrays of indexes plus a flexible way to extract data from arrays of structures (AoS). LIBXSMM also supports this (new) BLAS(-like) interface with `libxsmm_?gemm_batch` and `libxsmm_?gemm_batch_omp` (the latter of which relies on LIBXSMM/ext). Further, existing calls to `dgemm_batch` and `sgemm_batch` can be intercepted and replaced with [LIBXSMM's call wrapper](#call-wrapper). The signatures of `libxsmm_dgemm_batch` and `libxsmm_sgemm_batch` are equal except for the element type (`double` and `float` respectively). + +```C +void libxsmm_dgemm_batch(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], + const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]); +``` + +**Note**: the multi-threaded implementation (`ntasks > 1` or "omp" form of the functions) avoids data races if indexes or pointers for the destination (C-)matrix are duplicated. This synchronization occurs automatically (`beta != 0`), but can be avoided by passing a negative `batchsize`, `group_size` and/or a negative `group_count`. + +### User-Data Dispatch + +It can be desired to dispatch user-defined data, i.e., to query a value based on a key. This functionality can be used to, e.g., dispatch multiple kernels in one step if a code location relies on multiple kernels. This way, one can pay the cost of dispatch one time per task rather than according to the number of JIT-kernels used by this task. This functionality is detailed in the section about [Service Functions](libxsmm_aux.md#user-data-dispatch). + +### Call Wrapper + +#### Overview + +Since the library is binary compatible with existing GEMM calls (BLAS), such calls can be replaced at link-time or intercepted at runtime of an application such that LIBXSMM is used instead of the original BLAS library. There are two cases to consider: (1) static linkage, and (2) dynamic linkage of the application against the original BLAS library. When calls are intercepted, one can select a sequential (default) or an OpenMP-parallelized implementation (`make WRAP=2`). + +```bash +LIBXSMM STATISTIC: 1000 multiplications +dgemm(trans=NN mnk=32,32,21 ldx=32,21,32 a,b=1,0): 8% [main$omp$1] +dgemm(trans=NN mnk=32,21,32 ldx=32,32,32 a,b=1,0): 8% [main$omp$1] +dgemm(trans=NN mnk=10,21,32 ldx=10,32,10 a,b=1,0): 5% [main$omp$1] +dgemm(trans=NN mnk=32,10,32 ldx=32,32,32 a,b=1,0): 5% [main$omp$1] +dgemm(trans=NN mnk=32,32,10 ldx=32,10,32 a,b=1,0): 5% [main$omp$1] +``` + +Intercepted GEMMs can also build a sophisticated statistic (histogram) with LIBXSMM_VERBOSE=4 (or higher). The histogram displays the call sites (debug symbol name) of all intercepted GEMMs ([example](https://github.com/hfp/libxsmm/blob/master/samples/utilities/wrap/autobatch.c) above depicts an OpenMP region hosted by the main function). With level 5 (or higher), the histogram yields the entire content, and eventually less relevant entries are not pruned. An application must be built with symbols (`-g`) and export symbols similar to shared libraries (`-Wl,--export-dynamic` even when linked statically) in order to display the symbol names of where the GEMMs originated (call site). + +**Note**: Intercepting GEMM calls is low effort but implies overhead, which can be relatively high for small-sized problems. LIBXSMM's native programming interface has lower overhead and allows to amortize this overhead when using the same multiplication kernel in a consecutive fashion along with sophisticated data prefetch. + +#### Static Linkage + +An application which is linked statically against BLAS requires to wrap the `sgemm_` and the `dgemm_` symbol (an alternative is to wrap only `dgemm_`). To relink the application (without editing the build system) can often be accomplished by copying and pasting the linker command as it appeared in the console output of the build system, and then re-invoking a modified link step (please also consider `-Wl,--export-dynamic`). + +```bash +gcc [...] -Wl,--wrap=dgemm_,--wrap=sgemm_ \ + /path/to/libxsmmext.a /path/to/libxsmm.a \ + /path/to/your_regular_blas.a +``` + +In addition, existing [BLAS(-like) batch-calls](#blas-batch-interface) can be intercepted as well: + +```bash +gcc [...] -Wl,--wrap=dgemm_batch_,--wrap=sgemm_batch_ \ + -Wl,--wrap=dgemm_batch,--wrap=sgemm_batch \ + -Wl,--wrap=dgemm_,--wrap=sgemm_ \ + /path/to/libxsmmext.a /path/to/libxsmm.a \ + /path/to/your_regular_blas.a +``` + +Above, GEMM and GEMM_BATCH are intercepted both, however this can be chosen independently. For GEMM_BATCH the Fortran and C-form of the symbol may be intercepted both (regular GEMM can always be intercepted per `?gemm_` even when `?gemm` is used in C-code). + +**Note**: The static link-time wrapper technique may only work with a GCC tool chain (GNU Binutils: `ld`, or `ld` via compiler-driver), and it has been tested with GNU GCC, Intel Compiler, and Clang. However, this does not work under Microsoft Windows (even when using the GNU tool chain or Cygwin). + +#### Dynamic Linkage + +An application that is dynamically linked against BLAS allows to intercept the GEMM calls at startup time (runtime) of the unmodified executable by using the LD_PRELOAD mechanism. The shared library of LIBXSMMext (`make STATIC=0`) can be used to intercept GEMM calls: + +```bash +LD_LIBRARY_PATH=/path/to/libxsmm/lib:${LD_LIBRARY_PATH} \ +LD_PRELOAD=libxsmmext.so \ + ./myapplication +``` + diff --git a/third_party/libxsmm/documentation/libxsmm_prof.md b/third_party/libxsmm/documentation/libxsmm_prof.md new file mode 100644 index 0000000000000000000000000000000000000000..f5f7b0cf32ee24f4304ef295ba95774fbb70a729 --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_prof.md @@ -0,0 +1,45 @@ +## Performance Analysis + +### Intel VTune Profiler + +To analyze which kind of kernels have been called, and from where these kernels have been invoked (call stack), the library allows profiling its JIT code using Intel VTune Profiler. To enable this support, VTune's root directory needs to be set at build-time of the library. Enabling symbols (SYM=1 or DBG=1) incorporates VTune's JIT Profiling API: + +```bash +source /opt/intel/vtune_profiler/vtune-vars.sh +make SYM=1 +``` + +Above, the root directory is automatically determined from the environment (VTUNE_PROFILER_\*_DIR or VTUNE_AMPLIFIER_\*_DIR with older versions). This variable is present after source'ing the Intel VTune environment (`source /path/to/vtune_amplifier/amplxe-vars.sh` with older version), but it can be manually provided as well (`make VTUNEROOT=/path/to/vtune_amplifier`). Symbols are not really required to display kernel names for the dynamically generated code, however enabling symbols makes the analysis much more useful for the rest of the (static) code, and hence it has been made a prerequisite. For example, when "call stacks" are collected it is possible to find out where the JIT code has been invoked by the application: + +```bash +vtune -r resultdir -data-limit 0 -collect hotspots \ + -knob enable-stack-collection=true \ + -knob sampling-mode=hw \ + -knob stack-size=0 \ + -- ./myapplication +``` + +In case of an MPI-parallelized application, it can be useful to only collect results from a "representative" rank, and to also avoid running the event collector in every rank of the application. With Intel MPI both of which can be achieved by: + +```bash +mpirun -gtool 'vtune -r resultdir -data-limit 0 -collect hotspots \ + -knob sampling-mode=hw -knob enable-stack-collection=true \ + -knob stack-size=0:4=exclusive' \ + [...] ./myapplication +``` + +The `:4=exclusive` is related to Intel MPI or mpirun's gtool arguments and unrelated to VTune's command line syntax (see `vtune --help` or `amplxe-cl --help` with older versions); such argument(s) need to appear at the end of the gtool-string. For instance, the shown command line selects the 5th rank (zero-based) along with exclusive usage of the performance monitoring unit (PMU) such that only one event-collector runs for all ranks (without rank-number, all ranks are sampled). + +Intel VTune Profiler presents invoked JIT code like functions, which belong to a module named "libxsmm.jit". The function name as well as the module name are supplied by LIBXSMM using VTune's JIT-Profiling API. Below, the shown "function name" (`libxsmm_knl_dnn_23x23x23_23_23_23_a1_b1_p6::mxm`) encodes an AVX-512 ("knl") double-precision kernel ("d") for small dense matrix multiplication, which performs no transposes ("nn"). The name further encodes M=N=K=LDA=LDB=LDC=23, Alpha=Beta=1.0, and a prefetch strategy ("p6"). + +![The shown "function name" (`libxsmm_knl_dnn_23x23x23_23_23_23_a1_b1_p6::mxm`) encodes an Intel AVX-512 ("knl") double-precision kernel ("d") for small dense matrix multiplication, which performs no transposes ("nn"). The name further encodes M=N=K=LDA=LDB=LDC=23, Alpha=Beta=1.0, and some prefetch strategy ("p6").](libxsmm_prof-vtune.png) + +An application that cannot rely on LIBXSMM's build system can apply `-DLIBXSMM_VTUNE=2` during compilation, and link against `${VTUNE_AMPLIFIER_XE_2017_DIR}/lib64/libjitprofiling.a`. For example, TensorFlow with LIBXSMM and Intel VTune Profiler may use this way to gain insight into LIBXSMM's JIT-code (see [here](tensorflow.md#performance-profiling)). + +### Linux perf + +With LIBXSMM, there is both basic (`perf map`) and extended support (`jitdump`) when profiling an application. To enable perf support at runtime, the environment LIBXSMM_VERBOSE needs to be set to a negative value. + +* The basic support can be enabled at compile-time with PERF=1 (implies SYM=1) using `make PERF=1`. At runtime of the application, a map-file ('jit-*pid*.map') is generated ('/tmp' directory). This file is automatically read by Linux perf, and enriches the information about unknown code such as JIT'ted kernels. +* The support for "jitdump" can be enabled by supplying JITDUMP=1 (implies PERF=1) or PERF=2 (implies JITDUMP=1) when making the library: `make JITDUMP=1` or `make PERF=2`. At runtime of the application, a dump-file ('jit-*pid*.dump') is generated (in perf's debug directory, usually `$HOME/.debug/jit/`) which includes information about JIT'ted kernels (such as addresses, symbol names, code size, and the code itself). The dump file can be injected into `perf.data` (using `perf inject -j`), and it enables an annotated view of the assembly in perf's report (requires a reasonably recent version of Linux perf). + diff --git a/third_party/libxsmm/documentation/libxsmm_qna.md b/third_party/libxsmm/documentation/libxsmm_qna.md new file mode 100644 index 0000000000000000000000000000000000000000..f4071048517b844bdb329b66d139b184fd14cb6d --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_qna.md @@ -0,0 +1,58 @@ +## What is the background of the name "LIBXSMM"? +The "MM" stands for Matrix Multiplication, and the "S" clarifies the working domain i.e., Small Matrix Multiplication. The latter also means the name is neither a variation of "MXM" nor an eXtreme Small Matrix Multiplication but rather about Intel Architecture (x86) - and no, the library is [64‑bit only](https://github.com/hfp/libxsmm/issues/103#issuecomment-256887962). The spelling of the name might follow the syllables of libx\\/smm, libx'smm, or libx‑smm. +> **NOTE**: the library does [not](https://github.com/hfp/libxsmm/issues/103#issuecomment-256887962) support 32-bit architecture (64‑bit only) + +## What is a small matrix multiplication? +When characterizing the problem-size using the M, N, and K parameters, a problem-size suitable for LIBXSMM falls approximately within *(M N K)1/3 \<= 128* (which illustrates that non-square matrices or even "tall and skinny" shapes are covered as well). The library is typically used to generate code up to the specified [threshold](#auto-dispatch). Raising the threshold may not only generate excessive amounts of code (due to unrolling in M or K dimension), but also miss to implement a tiling scheme to effectively utilize the cache hierarchy. For auto-dispatched problem-sizes above the configurable threshold (explicitly JIT'ted code is **not** subject to the threshold), LIBXSMM is falling back to BLAS. In terms of GEMM, the supported kernels are limited to *Alpha := 1*, *Beta := \{ 1, 0 \}*, and *TransA := 'N'*. +> **NOTE**: *Alpha*, *Beta*, and *TransA* are limited to `1`, `{ 1, 0 }`, and `'N'` respectively. + +## What is a small convolution? +In the last years, new workloads such as deep learning and more specifically convolutional neural networks (CNN) emerged, and are pushing the limits of today's hardware. One of the expensive kernels is a small convolution with certain kernel sizes (3, 5, or 7) such that calculations in the frequency space is not the most efficient method when compared with direct convolutions. LIBXSMM's current support for convolutions aims for an easy to use invocation of small (direct) convolutions, which are intended for CNN training and classification. The [Interface](#interface-for-convolutions) is currently ramping up, and the functionality increases quickly towards a broader set of use cases. + +## What about "medium-sized" and big(ger) matrix multiplications? +A more recent addition are GEMM routines, which are parallelized using OpenMP (`libxsmm_?gemm_omp`). These routines leverage the same specialized kernel routines as the small matrix multiplications, in-memory code generation (JIT), and automatic code/parameter dispatch but they implement a tile-based multiplication scheme i.e., a scheme that is suitable for larger problem-sizes. For *Alpha*, *Beta*, *TransA*, and *TransB*, the limitations of the small matrix multiplication kernels apply. More details can be found in the [description of the xgemm sample code](https://github.com/hfp/libxsmm/tree/master/samples/xgemm#xgemm-tiled-gemm-routines). + +## How to determine whether an application can benefit from using LIBXSMM or not? +Given the application uses BLAS to carry out matrix multiplications, one may use the [Call Wrapper](#call-wrapper), and measure the application performance e.g., time to solution. However, the latter can significantly improve when using LIBXSMM's API directly. To check whether there are applicable GEMM-calls, the [Verbose Mode](#verbose-mode) can help to collect an insight. Further, when an application uses [Intel MKL 11.2](https://registrationcenter.intel.com/en/forms/?productid=2558) (or higher), then running the application with the environment variable MKL_VERBOSE=1 (`env MKL_VERBOSE=1 ./workload > verbose.txt`) can collect a similar insight (`grep -a "MKL_VERBOSE DGEMM(N,N" verbose.txt | cut -d'(' -f2 | cut -d, -f3-5"`). + +## Is LIBXSMM compatible from version-to-version, or what is the ABI commitment? +One may have a look at issue [#120](https://github.com/hfp/libxsmm/issues/120#issuecomment-264498939) or [#282](https://github.com/hfp/libxsmm/issues/282#issuecomment-485390494), but in summary: +* Binary compatibility is not continuously tested (only manually for a subset of the API namely SMM domain). +* Major versions are likely breaking binary compatibility with existing integrations (that is typical). +* Minor versions may break binary compatibility of recently introduced features (may not be typical). +* Update and patch versions are binary compatible but may only be released on request (issue). + +LIBXSMM's API for Small Matrix Multiplications (SMMs) is considered stable, and all major known applications (e.g., CP2K, EDGE, NEK5K, and SeisSol) either rely on SMMs or are able (and want) to benefit from an improved API of any of the other domains (e.g., DL). Until at least v2.0, LIBXSMM is not able to track or even maintain binary compatibility and hence the SONAME also goes with the semantic version. A [list of public functions](https://github.com/hfp/libxsmm/blob/master/.abi.txt) is maintained (but there is no distinction for a small subset of them that are only meant for communication between LIBXSMM and LIBXSMM/ext). + +## I am relying on a prebuilt version of CP2K (or another application), is LIBXSMM incorporated and which version is it? +This can be determined using the environment variable `LIBXSMM_VERBOSE=2` (or higher verbosity). It is not even required to use an input or workload since the information in question is presented when the program terminates. For example: + +``` +LIBXSMM_VERBOSE=1 exe/Linux-x86-64-intelx/cp2k.psmp +[...] +LIBXSMM_VERSION: release-1.11 +LIBXSMM_TARGET: clx +``` + +## I am relying on a prebuilt version of an application, and I am concerned about optimal compiler flags. +LIBXSMM uses JIT-generated code according to the CPUID of the system. This is independent of the compiler flags used to build the library. If LIBXSMM was incorporated per [classic ABI](https://libxsmm.readthedocs.io/#classic-library-abi), `LIBXSMM_DUMP_BUILD=1` environment variable allows to print build flags used for LIBXSMM at termination of the application. This output of `LIBXSMM_DUMP_BUILD=1` can yield hints about the flags used to build the application (if similar). + +For concerns regarding the code of an application that cannot benefit from LIBXSMM, one may have a look at the build recipes of the [XCONFIGURE](http://xconfigure.readthedocs.io/) project. + +## What Operating Systems are covered by LIBXSMM, and what about Microsoft Windows? +The answer here focuses on the actual runtime support rather than the supported compiler tool chains used to build the library. All flavors of Linux are supported (if the library was successfully built), which includes installations running a security-hardened Linux kernel (SELinux). The Apple OS (OSX) is supported, which also includes more recent SIP-enabled versions (System Integrity Protection). The BSD OS is likely supported, but building the library is only occasionally validated. Microsoft Windows is supported for non-JIT operation, and for most (e.g., GEMM and MATCOPY) of the JIT-kernels (prefetch signature is not supported). There is currently no support for JIT in the DNN domain (no further check is performed i.e., crash at runtime). See also [issue #71](https://github.com/hfp/libxsmm/issues/71). + +## Does LIBXSMM has some support for GEMV? +The library generates acceptable code when using `M=1` or `N=1`. For example, building with `make M=16 N=1 K=16 AVX=2` and inspecting the assembly (build directory) or dumping/disassembling the JIT code (see reference documentation) shows the minimum number of load/store instructions. Given that GEMV is a memory bound operation, this suggests reasonable code quality. LIBXSMM selects from multiple microkernels (specific for each ISA extension) by using a fixed scheme/heuristic, which should be acceptable for GEMV. The sample code under [samples/smm](https://github.com/hfp/libxsmm/blob/master/samples/smm) provides ready-to-use benchmark drivers that can help to compare the performance with LAPACK/BLAS. Afore mentioned benchmarks exercise streaming all possible combinations of operands. + +## What about complex and mixed types? +This question refers to the following kind of element type of the GEMM interface of LIBXSMM: +* Complex types: complex numbers in single and double-precision, +* Mixed types: e.g. real double-precision and complex double-precision +There are no (immediate) plans to support more types for the GEMM part. Please note, that LIBXSMM indeed supports lower precision GEMM (wgemm). + +## What about voting for features? +All feedback and [issue reports](https://github.com/hfp/libxsmm/issues) are handled openly, are welcome and considered ([answered](https://github.com/hfp/libxsmm/issues?q=is%3Aissue+is%3Aclosed), and [collected](https://github.com/hfp/libxsmm/wiki/Development#longer-term-issues)). However, we do not seek for "feature votes" since the development of the library is not a democratic process. + +## \ What is the purpose of ROW_MAJOR vs. COL_MAJOR? +This build configuration is deprecated ([issue 85](https://github.com/hfp/libxsmm/issues/85)), otherwise there is nothing one cannot achieve with row-major as opposed to column-major storage order. In particular the choice is not about whether a program is written in C/C++ or in FORTRAN. The ROW_MAJOR setting is just offered for existing code, which calls into function(s) that assume row-major storage order and where these calls are to be replaced by LIBXSMM in a "1:1 fashion". It is encouraged to avoid the ROW_MAJOR setting since BLAS implies COL_MAJOR (and LIBXSMM is supposed to be compatible with BLAS). [More...](https://github.com/hfp/libxsmm/issues/80) diff --git a/third_party/libxsmm/documentation/libxsmm_samples.md b/third_party/libxsmm/documentation/libxsmm_samples.md new file mode 100644 index 0000000000000000000000000000000000000000..a99e5c2ca43d1e425b9df01c4bc0df82bf638e87 --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_samples.md @@ -0,0 +1,706 @@ +# [LIBXSMM Samples](https://github.com/hfp/libxsmm/raw/master/documentation/libxsmm_samples.pdf) + +## CP2K Artificial Benchmark + +The first code sample given for LIBXSMM was a performance reproducer exercising the same set of kernels usually generated for CP2K's SMM library. The code sample attempted to model the way "matrix stacks" are processed in CP2K, however there are two different code paths in CP2K: (1) the "main" code path used when processing stacks on the host-side, and (2) a code path targeting offload devices. Beside of the host-sided parallelization via MPI (and perhaps OpenMP), the secondly mentioned code path relies on an additional level of parallelization (which is obviously necessary to drive a potentially highly parallel offload device). Also, the additional level of parallelism is not exactly "nested" in the sense that it participates on sharing the same resources as the host-side. In fact, this "artificial benchmark" (cp2k code sample) is modeling a code path as utilized in the secondly mentioned case (offload device). + +## Hello LIBXSMM + +This example is focused on a specific functionality but may be considered as "Hello LIBXSMM". Copy and paste the example code and build it either manually and as described in our [main documentation](https://libxsmm.readthedocs.io/#hello-libxsmm) (see underneath the source code), or use GNU Make: + +```bash +cd /path/to/libxsmm +make + +cd /path/to/libxsmm/samples/hello +make + +./hello +``` + +Alternatively, one can use the Bazel build system. To further simplify, [Bazelisk](https://github.com/bazelbuild/bazelisk) is used to boot-strap [Bazel](https://bazel.build/): + +```bash +cd /path/to/libxsmm/samples/hello +bazelisk build //... + +./bazel-bin/hello +``` + +The [C/C++ code](https://github.com/hfp/libxsmm/blob/master/samples/hello/hello.cpp) given here uses LIBXSMM in header-only form (`#include `), which is in contrast to the code shown in the [main documentation](https://libxsmm.readthedocs.io/#hello-libxsmm). The [Fortran code](https://github.com/hfp/libxsmm/blob/master/samples/hello/hello.f) (`hello.f`) can be manually compiled like `gfortran -I/path/to/libxsmm/include hello.f -L/path/to/libxsmm/lib -libxsmmf -lxsmm -lxsmmnoblas -o hello` or as part of the above described invocation of GNU Make. + +## Magazine + +### Overview + +This collection of code samples accompany an article written for [issue #34](https://software.intel.com/sites/default/files/parallel-universe-issue-34.pdf) of the magazine [The Parallel Universe](https://software.intel.com/en-us/download/parallel-universe-magazine-issue-34-october-2018), an Intel publication. The articles focuses on Blaze-, Eigen-, and LIBXSMM-variants of Small Matrix Multiplications (SMMs). The set of sample codes now also includes a variant relying on BLAS and a variant that showcases LIBXSMM's explicit batch-interface. + +The baseline requirements are libraries that can operate on column-major storage order, "zero copy" when using existing memory buffers, and an API that is powerful enough to describe leading dimensions. Typically a library-internal parallelization of matrix multiplication is desired. However, for the magazine sample collection there is no performance gain expected since the matrices are small, and nested parallelism may only add overhead. Hence library-internal parallelism is disabled (BLAZE_USE_SHARED_MEMORY_PARALLELIZATION=0, EIGEN_DONT_PARALLELIZE). LIBXSMM provides parallelization on a per-functions basis and no global toggle is needed. + +The sample codes rely on the minimum programming language supported by the library in question (API): C++ in case of Blaze and Eigen, and C in case of LIBXSMM (both C++ and Fortran interfaces are available as well). For Blaze and Eigen, the build-system ensures to not map implementation into a BLAS library (normally desired but this would not test the library-native implementation). + +### Results + +To reproduce or repeat the performance measurements on a system of choice, all matrix operands are streamed by default. The file [magazine.h](https://github.com/hfp/libxsmm/blob/master/samples/magazine/magazine.h) can be edited to reproduce the desired combination (STREAM_A, STREAM_B, and STREAM_C). Whether or not matrix operands are streamed is motivated in publication. To reduce dependency on the compiler's OpenMP implementation, the benchmarks run single-threaded by default (`make OMP=1` can parallelize the batch of matrix multiplications). The outer/batch-level parallelization is also disabled to avoid accounting for proper first-touch memory population on multi-socket systems (NUMA). For the latter, the init-function (located in magazine.h) is not parallelized for simplicity. + +```bash +cd libxsmm; make +cd samples/magazine; make +``` + +To run the benchmark kernels presented by the article: + +```bash +./benchmark.sh +``` + +Please note that if multiple threads are enabled and used, an appropriate pin-strategy should be used (OMP_PLACES=threads, OMP_PROC_BIND=TRUE). To finally produce the benchmark charts: + +```bash +./benchmark-plot.sh blaze +./benchmark-plot.sh eigen +./benchmark-plot.sh xsmm +``` + +The plot script relies at least on Gnuplot. ImageMagick (mogrify) can be also useful if PNGs are created, e.g., `./benchmark-plot.sh xsmm png 0` (the last argument disables single-file charts in contrast to multi-page PDFs created by default, the option also disables chart titles). + +The set of kernels executed during the benchmark can be larger than the kernels presented by the plots: [benchmark.set](https://github.com/hfp/libxsmm/blob/master/samples/magazine/benchmark.set) selects the kernels independent of the kernels executed (union). + +## NEK Sample Collection + +This directory contains kernels taken from Nek{Box,5000}. They aim to represent most of the matrix-matrix workloads. + +Please note that the [mxm_std.f](https://github.com/hfp/libxsmm/blob/master/samples/nek/mxm_std.f) source code is protected by an (US) GOVERNMENT LICENSE, and under the copyright of the University of Chicago. + +### stpm + +Small tensor-product multiple (stpm) replicates the axhelm kernel, which computes the Laplacian with spectral elements. +Usage: + +```bash +./stpm m n k size1 size +``` + +The elements are m-by-n-by-k, mode picks the LIBXSMM interface used, and size scales the number of spectral elements. + +### rstr + +Restriction operator transforms elements from one size to another. This occurs in multi-grid, the convection operator, and, when the sizes are the same, the local Schwarz solves. Usage: + +```bash +./rstr m n k mm nn kk size1 size +``` + +The input elements are m-by-n-by-k and the output elements are mm-by-nn-by-kk. When m=mm, n=nn, k=kk, this half of a Schwarz solve. + +## SMM Sample Collection + +This collection of code samples exercises different memory streaming cases when performing the matrix multiplication *C~m x n~ = alpha · A~m x k~ · B~k x n~ + beta · C~m x n~*: (1) streaming the matrices A, B, and C which is usually referred as batched matrix multiplication, (2) streaming the inputs A and B but accumulating C within cache, (3) streaming the A and C matrices while B is kept in cache, (4) streaming the B and C matrices while A is kept in cache, and (4) not streaming any of the operands but repeating the very same multiplication until the requested number of matrix multiplications has been completed. + +Beside of measuring the duration of a test case, the performance is presented in GFLOPS/s. As an alternative metric, the memory bandwidth is given (the artificial "cached" case omits to present the cache-memory bandwidth). The "pseudo-performance" given in FLOPS/cycle is an artificial scoring, it not only uses a non-standard formula for calculating the FLOPS (*2 \* M \* N \* K - M \* N* rather than *2 \* M \* N \* K*) but also relies on (pseudo-)clock cycles: + +``` +$ ./specialized.sh 0 +m=32 n=32 k=32 size=87381 memory=2048.0 MB (DP) + +Batched (A,B,C)... + pseudo-perf.: 10.7 FLOPS/cycle + performance: 23.9 GFLOPS/s + bandwidth: 11.1 GB/s + duration: 239 ms +Finished +``` + +There are two sub collections of samples codes: (1) a collection of C++ code samples showing either BLAS, Compiler-generated code (inlined code), LIBXSMM/dispatched, LIBXSMM/specialized functions to carry out the multiplication, and (2) a Fortran sample code showing BLAS versus LIBXSMM including some result validation. + +**C/C++ Code Samples: Command Line Interface (CLI)** + +* Takes an optional number (1st arg.) to select the streaming-case (0...8) +* Optionally takes the M, N, and K parameter of the GEMM in this order +* If only M is supplied, the N and K "inherit" the M-value +* Example I (A,B,C): ./specialized.sh 0 16 8 9 +* Example II (A,B): ./specialized.sh 6 16 + +**Fortran Code Sample: Command Line Interface (CLI)** + +* Optionally takes the M, N, and K parameter of the GEMM in this order +* Optional problem size (in MB) of the workload; M/N/K must have been supplied +* Optional total problem size (in MB) implying the number of repeated run +* If only M is supplied, the N and K are "inheriting" the M-value +* Shows the performance of each of the streaming cases +* Example I: ./smm.sh 16 8 9 1024 16384 +* Example II: ./smm.sh 16 + +## SPECFEM Sample + +This sample contains a dummy example from a spectral-element stiffness kernel taken from [SPECFEM3D_GLOBE](https://github.com/geodynamics/specfem3d_globe). + +It is based on a 4th-order, spectral-element stiffness kernel for simulations of elastic wave propagation through the Earth. Matrix sizes used are (25,5), (5,25) and (5,5) determined by different cut-planes through a three dimensional (5,5,5)-element with a total of 125 GLL points. + + +### Usage Step-by-Step + +This example needs the LIBXSMM library to be built with static kernels, using MNK="5 25" (for matrix size (5,25), (25,5) and (5,5)). + +#### Build LIBXSMM + +##### General Default Compilation + +In LIBXSMM root directory, compile the library with: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 +``` + +##### Additional Compilation Examples + +Compilation using only single precision version and aggressive optimization: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 +``` + +For Sandy Bridge CPUs: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=1 +``` + +For Haswell CPUs: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=2 +``` + +For Knights Corner (KNC) (and thereby creating a Sandy Bridge version): + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=1 \ +OFFLOAD=1 KNC=1 +``` + +Installing libraries into a sub-directory workstation/: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=1 \ +OFFLOAD=1 KNC=1 \ +PREFIX=workstation/ install-minimal +``` + +#### Build SpecFEM example code + +For default CPU host: + +```bash +cd sample/specfem +make +``` + +For Knights Corner (KNC): + +```bash +cd sample/specfem +make KNC=1 +``` + +Additionally, adding some specific Fortran compiler flags, for example: + +```bash +cd sample/specfem +make FCFLAGS="-O3 -fopenmp" [...] +``` + +Note that steps 1 and 2 could be shortened by specifying a "specfem" make target in the LIBXSMM root directory: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=1 specfem +``` + +For Knights Corner, this would need two steps: + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=1 OFFLOAD=1 KNC=1 +make OPT=3 specfem_mic +``` + +### Run the Performance Test + +For default CPU host: + +```bash +./specfem.sh +``` + +For Knights Corner (KNC): + +```bash +./specfem.sh -mic +``` + +### Results + +Using Intel Compiler suite: icpc 15.0.2, icc 15.0.2, and ifort 15.0.2. + +#### Sandy Bridge - Intel(R) Xeon(R) CPU E5-2670 0 @ 2.60GHz + +Library compilation by (root directory): + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=1 +``` + +Single threaded example run: + +```bash +cd sample/specfem +make; OMP_NUM_THREADS=1 ./specfem.sh +``` + +Output: + +```bash +=============================================================== +average over 15 repetitions + timing with Deville loops = 0.1269 + timing with unrolled loops = 0.1737 / speedup = -36.87 % + timing with LIBXSMM dispatch = 0.1697 / speedup = -33.77 % + timing with LIBXSMM prefetch = 0.1611 / speedup = -26.98 % + timing with LIBXSMM static = 0.1392 / speedup = -9.70 % +=============================================================== +``` + +#### Haswell - Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz + +Library compilation by (root directory): + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 AVX=2 +``` + +Single threaded example run: + +```bash +cd sample/specfem +make; OMP_NUM_THREADS=1 ./specfem.sh +``` + +Output: + +```bash +=============================================================== +average over 15 repetitions + timing with Deville loops = 0.1028 + timing with unrolled loops = 0.1385 / speedup = -34.73 % + timing with LIBXSMM dispatch = 0.1408 / speedup = -37.02 % + timing with LIBXSMM prefetch = 0.1327 / speedup = -29.07 % + timing with LIBXSMM static = 0.1151 / speedup = -11.93 % +=============================================================== +``` + +Multi-threaded example run: + +```bash +cd sample/specfem +make OPT=3; OMP_NUM_THREADS=24 ./specfem.sh +``` + +Output: + +```bash +OpenMP information: + number of threads = 24 + +[...] + +=============================================================== +average over 15 repetitions + timing with Deville loops = 0.0064 + timing with unrolled loops = 0.0349 / speedup = -446.71 % + timing with LIBXSMM dispatch = 0.0082 / speedup = -28.34 % + timing with LIBXSMM prefetch = 0.0076 / speedup = -19.59 % + timing with LIBXSMM static = 0.0068 / speedup = -5.78 % +=============================================================== +``` + +#### Knights Corner - Intel Xeon Phi B1PRQ-5110P/5120D + +Library compilation by (root directory): + +```bash +make MNK="5 25" ALPHA=1 BETA=0 PRECISION=1 OPT=3 OFFLOAD=1 KNC=1 +``` + +Multi-threaded example run: + +```bash +cd sample/specfem +make FCFLAGS="-O3 -fopenmp -warn" OPT=3 KNC=1; ./specfem.sh -mic +``` + +Output: + +```bash +OpenMP information: + number of threads = 236 + +[...] + +=============================================================== +average over 15 repetitions + timing with Deville loops = 0.0164 + timing with unrolled loops = 0.6982 / speedup = -4162.10 % + timing with LIBXSMM dispatch = 0.0170 / speedup = -3.89 % + timing with LIBXSMM static = 0.0149 / speedup = 9.22 % +=============================================================== +``` + +## Matrix Transpose (TCOPY) + +### Overview + +This code sample aims to benchmark the performance of matrix transposes. The C/C++ and [FORTRAN sample code](https://github.com/hfp/libxsmm/blob/master/samples/transpose/transpose.f) differ slightly with the C/C++ code sample offering a richer set of command line options as well as build settings available inside of the [translation unit](https://github.com/hfp/libxsmm/blob/master/samples/transpose/transpose.c). + +The available command line options of the sample code may be reviewed by looking into the source code. Generally, the idea is to support the following: + +> transpose [<kind> [<m> [<n> [<ldi> [<ldo>]]]]] +transposef [<m> [<n> [<ldi> [<ldo>]]]] + +Above, `m` and `n` specify the matrix shape, and `ldi` the leading dimension of the matrix. The argument `ldo` allows to specify an output dimension, which may differ from `ldi`. The transpose kind shall be either out-of-place (`o`) or in-place (`i`). + +Running the C sample code may look like: + +```bash +$ ./transpose.sh o 20000 +m=20000 n=20000 ldi=20000 ldo=20000 size=3052MB (double, out-of-place) + bandwidth: 18.8 GB/s + duration: 159 ms +``` + +Instead of executing a wrapper script, one may affinitize the multi-threaded execution manually (OpenMP runtime). In case of an executable built using the Intel Compiler this may look like: + +```bash +LIBXSMM_VERBOSE=2 KMP_AFFINITY=balanced,granularity=fine,1 \ +./transpose o 20000 +m=20000 n=20000 ldi=20000 ldo=20000 size=3052MB (double, out-of-place) + bandwidth: 21.1 GB/s + duration: 141 ms + +Registry: 20 MB (gemm=0 mcopy=0 tcopy=1) +``` + +In the above case one can see from the verbose output (`LIBXSMM_VERBOSE=2`) that one kernel (tcopy) served transposing the entire matrix. To avoid duplicating JIT-kernels under contention (code registry), one may also consider `LIBXSMM_TRYLOCK=1`, which is available per API-call as well. + +### OpenTuner + +To tune the tile sizes ("block sizes") internal to LIBXSMM's transpose routine, the [OpenTuner](http://opentuner.org/) extensible framework for program autotuning can be used. In case of issues during the tuning phase ("no value has been set for this column"), please install the latest 1.2.x revision of SQLAlchemy (`pip install sqlalchemy==1.2.19`). A tuning script (`transpose_opentuner.py`) is provided, which accepts a range of matrix sizes as command line arguments. + +> transpose_opentuner.py <begin> <end> [*nexperiments-per-epoch*] [*tile-size-m*] [*tile-size-n*] + +To start a tuning experiment for a new set of arguments, it is highly recommended to start from scratch. Otherwise the population of previously generated tuning results is fetched from a database and used to tune an eventually unrelated range of matrix shapes. To get reliable timings, the total time for all experiments per epoch is minimized (hence a different number of experiments per epoch also asks for an own database). Optionally, the initial block size can be seeded (`tile-size-m` and `tile-size-n`). + +```bash +rm -rf opentuner.db +``` + +The script tunes matrices with randomized shape according to the specified range. The leading dimension is chosen tightly for the experiments. The optimizer not only maximizes the performance but also minimizes the value of *M \* N* (which also helps to prune duplicated results due to an additional preference). + +```bash +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 1 1024 1000 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 1024 2048 100 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 2048 3072 20 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 3072 4096 20 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 4096 5120 16 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 5120 6144 12 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 6144 7168 8 + +rm -rf opentuner.db +./transpose_opentuner.py --no-dups 7168 8192 6 +``` + +The tuning script uses the environment variables `LIBXSMM_TCOPY_M` and `LIBXSMM_TCOPY_N`, which are internal to LIBXSMM. These variables are used to adjust certain thresholds in `libxsmm_otrans` or to request a specific tiling-scheme inside of the `libxsmm_otrans_omp` routine. + +## XGEMM: Tiled GEMM Routines + +### Overview + +This sample code calls the `libxsmm_?gemm_omp` routines provided by the LIBXSMM extension library (`libxsmmext`). These routines are meant for big(ger) xGEMM routines, and thereby provide an OpenMP-based parallelization. + +The driver program (`xgemm.c`) currently accepts all typical GEMM arguments (except for the transposition specifier): `m`, `n`, `k`, `lda`, `ldb`, `ldc`, `alpha`, and `beta`. All arguments are optional (or will inherit defaults from previously specified arguments). Matrix transposition as part of the `libxsmm_?gemm_omp` routines will become available in an upcoming release of LIBXSMM. Please also note that unsupported Alpha or Beta values will cause a fall back to the related BLAS routine. The single-precision matrix multiplications require to change the `ITYPE` in `xgemm.c`. + +```bash +./xgemm.sh 2000 +``` + +### OpenTuner + +To tune the tile sizes ("block sizes") internal to LIBXSMM, the [OpenTuner](http://opentuner.org/) extensible framework for program autotuning can be used. In case of issues during the tuning phase ("no value has been set for this column"), please install the latest 1.2.x revision of SQLAlchemy (`pip install sqlalchemy==1.2.19`). A tuning script (`xgemm_opentuner.py`) is provided, which optionally accepts a list of grouped parameters as command line arguments. The syntax of the arguments is per LIBXSMM's `MNK` build-option, and expands to "triplets" specifying the matrix shapes. For instance, four matrix multiplications of square-matrices can be benchmarked and tuned using the following command. + +```bash +./xgemm_opentuner.py 1024,1280,1536,1792 +``` + +To start a tuning experiment for a new set of arguments, it is highly recommended to start from scratch. Otherwise the population of previously generated tuning results is fetched from a database and used to tune an unrelated range of matrix shapes. Optionally, the initial block size can be seeded (`tile-size-m`, `tile-size-n`, and `tile-size-k`). + +```bash +rm -rf opentuner.db +``` + +The script tunes the geometric mean of the performance for each of the requested triplets. However, the optimizer not only maximizes the performance but also minimizes the value of *M \* N \* K* (which also helps to prune duplicated results due to an additional preference). As a limitation of the current implementation, the multiplication kernels are not accompanied by copy-kernels (and not accompanied by transpose kernels). This negatively impacts performance on power-of-two matrix shapes (POT) due to trashing the LLC. However, it has been found, that tuning for POT shapes likely achieves superior performance when compared to tuning for non-POT shapes of the same range. + +```bash +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 192,256,320,512,768 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 1024,1280,1536,1792 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 2048,2304,2560,2816 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 3072,3328,3584,3840 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 4096,4416,4736 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 5120,5440,5760 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 6144,6464,6784 + +rm -rf opentuner.db +./xgemm_opentuner.py --no-dups 7168,7488,7808 +``` + +Above, the series of matrix multiplications from 192-8K is separately tuned in eight ranges. The tuning script uses the environment variables `LIBXSMM_TGEMM_M`, `LIBXSMM_TGEMM_N`, and `LIBXSMM_TGEMM_K` which are internal to LIBXSMM. These variables are used to request a specific tiling-scheme within LIBXSMM's `libxsmm_?gemm_omp` routines. + + +This package contains the optimized kernels for the 1D dilated convolutional layer. +The C++ implementation has code for both FP32 and BF16 formats. +You can run this code on AVX-512 enabled CPUs. Ex. - Cascade Lake or Cooper lake. + + Install instructions + +IInstall PyTorch in an anaconda or virtual environment before installing the package. +Use GCC version 8.3.0 or higher. +conda activate environment # Activate anaconda or virtual environment containing PyTorch + +cd Conv1dOpti-extension/ +python setup.py install # Install package +cd .. + + +A user can either use run.sh script to run the torch_example.py code or +he/she can follow the following commands. + +export LD_LIBRARY_PATH={LIBXSMM_ROOT/lib} # Set LD_LIBRARY_PATH +export OMP_NUM_THREADS=28 # Set number of threads +export KMP_AFFINITY=compact,1,0,granularity=fine # Set KMP affinity + +python torch_example.py # Run the pytorch example + +In the previous example, we compare "nn.Conv1d" layer with our optimized "Conv1dOpti" layer. +The example shows how "nn.Conv1d" can be replaced with "Conv1dOpti" layer in a neural network +without requiring any other change. +The optimized python layer can be imported using "from Conv1dOpti_ext import Conv1dOpti" in python. +The example checks the accuracy of the results and calculates the computation time of both layers. + + + Limitations of the current version + +- Keep padding=0 in the options. The current layer doesn't do padding. Explicit padding is needed + for the optimized convolutional layer. You can use the example for reference. +- Optimized convolutional layer code can only run with stride = 1. +- Similarly, apply the nonlinearity (Ex. ReLU) separately. + + +To run code in BFloat16, set enable_BF16 flag to True. BFloat16 code runs only when the parameters of +Input width, number of filters and input channels to the layer are even number. +Ex. - Filters = 16, Channels = 16, Input width = 60000, enable_BF16 = True BF16 run +If any of the previous parameter is odd number then code runs in FP32 format. + + +Keep batch size as multiple of ununtilized cores (Ex. - 28, 56, 84, 128 .... on a 28 core cascade lake) +for optimal performance with the Conv1dOpti layer. Each batch will run on a seperate thread thus +performance may go down if some core are not free, or batch size is not equal to the number of free cores. +Keep the batch size as power of 2 with the MKLDNN backend (Conv1d) for optimal performance. # Deep Learning with GxM + +### Compiling and Building GxM + +1. Install Pre-requisite Libraries: Google logging module (glog), gflags, Google's data interchange format (Protobuf), OpenCV, LMDB +2. In Makefile.config, set GXM_LIBRARY_PATH variable to the path containing above libraries +3. In Makefile.config, set LIBXSMM_PATH variable to the path containing LIBXSMM library +4. Set/clear other flags in Makefile.config as required (see associated comments in Makefile.config) +5. source setup_env.sh +6. make clean; make + +### Running GxM + +The network topology definitions directory is "model_zoo". Currently, it contains definitions for +AlexNet (without LRN), ResNet-50, Inception v3 along with CIFAR10 and MNIST as simple test definitions. +Each topology definition is in a .prototxt file. ResNet-50 can run with "dummy data", raw JPEG image data +or with LMDB. Filenames indicate the data source along with the minibatch size. Inception v3 runs only with +compressed LMDB data. + +The hyperparameter definitions for each topology are also in the corresponding directory under "model_zoo" in +a .prototxt file with the suffix "solver". For a single-node, this file is called solver.prototxt. For multi-node +the filename also contains the global minibatch size (=single node minibatch size x number of nodes);, e.g., solver_896.prototxt contains hyperparameters for MB=56 per node and 16 nodes. The "solver*" file also contains a +flag that specifies whether to start execution from a checkpoint (and thus read load weights from the "./weights" +directory) or from scratch; by default execution starts from scratch. + +Optimal parallelization of Convolutional layers in LIBXSMM happens when the number of OpenMP threads = MiniBatch. +Therefore, on Xeon + +```bash +export OMP_NUM_THREADS= +export KMP_AFFINITY=compact,granularity=fine,1,0 +``` + +The command line for a training run is: + +```bash +./build/bin/gxm train +``` + +For example: + +```bash +./build/bin/gxm train model_zoo/resnet/1_resnet50_dummy_56.prototxt model_zoo/resnet/solver.prototxt +``` + +### Preping on RHEL 8.0 / CentOS 8.0 + +```bash +dnf install protobuf +wget http://mirror.centos.org/centos/8/PowerTools/x86_64/os/Packages/protobuf-compiler-3.5.0-7.el8.x86_64.rpm +dnf install protobuf-compiler-3.5.0-7.el8.x86_64.rpm +wget http://mirror.centos.org/centos/8/PowerTools/x86_64/os/Packages/protobuf-devel-3.5.0-7.el8.x86_64.rpm +dnf install protobuf-devel-3.5.0-7.el8.x86_64.rpm +dnf install lmdb +dnf install lmdb-devel +wget http://repo.okay.com.mx/centos/8/x86_64/release/opencv-devel-3.4.1-9.el8.x86_64.rpm +wget http://repo.okay.com.mx/centos/8/x86_64/release/opencv-3.4.1-9.el8.x86_64.rpm +dnf install opencv-3.4.1-9.el8.x86_64.rpm +dnf install opencv-devel-3.4.1-9.el8.x86_64.rpm +wget http://mirror.centos.org/centos/8/PowerTools/x86_64/os/Packages/gflags-2.1.2-6.el8.x86_64.rpm +wget http://mirror.centos.org/centos/8/PowerTools/x86_64/os/Packages/gflags-devel-2.1.2-6.el8.x86_64.rpm +dnf install gflags-2.1.2-6.el8.x86_64.rpm +dnf install gflags-devel-2.1.2-6.el8.x86_64.rpm +wget http://mirror.centos.org/centos/8/PowerTools/x86_64/os/Packages/glog-devel-0.3.5-3.el8.x86_64.rpm +wget http://mirror.centos.org/centos/8/PowerTools/x86_64/os/Packages/glog-0.3.5-3.el8.x86_64.rpm +dnf install glog-0.3.5-3.el8.x86_64.rpm +dnf install glog-devel-0.3.5-3.el8.x86_64.rpm +``` + +Make sure that the makefile follows the OpenCV Ver 3 path! + +## DNN Training with Incremental Sparsification + Sparse JIT Kernels + +### This project contains code for the following DNN models + +1. Resnet - ported from [link](https://pytorch.org/vision/stable/models.html) +2. Transformer - ported from [link](https://github.com/pytorch/fairseq) +3. DLRM - ported from [link](https://github.com/facebookresearch/dlrm) +4. PCL_MLP - A python extension of the `torch.nn.Linear` module that uses efficient sparse JIT kernels for matrix multiplication (supports forward, backward and update pass) - ported from [link](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning/sparse_weight_mult) + +### Features + +1. Training scripts for all three models located at the root of each directory in a form of a shell file +2. By specifying each of the four parameters, the pruning criteria (magnitude-based or random-based), the pruning start time and end time and target sparsity you can apply incremental sparsity to model weights for training +3. Additionally, by specifying a tensorboard log directory, one can examine training logs and metrics using tensorboard. + +### Data preparation + +Each model requires an extensive amount of data to be properly stress-tested against incremental sparsity. According to [The State of Sparsity](https://arxiv.org/abs/1902.09574) and by extensive experimentation, using a relatively small dataset or an overparameterized model may lead to false performance implications. For instance, when a ResNet-50 model is trained with the CIFAR-10 dataset or if the base Transformer is trained with a limited sentence pair dataset (i.e., EN-VI) it may seem as if the model isn't impacted even with extremely high sparsity since the model was overdetermined to begin with. + +- For Resnet +- For Resnet training, a smaller subset of ImageNet was used, called ImageNette due to its massiveness in size. Download from [here](https://github.com/fastai/imagenette). +- For Transformer +- As a neural machine translation task, the transformer model requires the WMT2014 EN_DE dataset. Preprocessing steps are described [here](https://fairseq.readthedocs.io/en/latest/getting_started.html#data-pre-processing) +- For DLRM +- Training the DLRM requires the terabyte dataset [link](https://labs.criteo.com/2013/12/download-terabyte-click-logs/) + +### Running scripts + +Each project consists of two scripts: a script that launches `sbatch` scripts for experimenting various target sparsities (usually named as `launch_pruning_runs.sh`)and a script that runs a single experiment. Use accordingly. + +1. ResNet model +`./launch_pruning_jobs.sh ${TARGET_SPARSITY}` or +`python train.py ${TARGET_SPARSITY}` +2. Transformer(FAIRSEQ) model +`./launch_pruning_runs.sh` or `./prune_en_de.sh ${TARGET_SPARSITY} ${PRUNE_TYPE} ${EMB}` +where PRUNE_TYPE is either `magnitude` or `random` and EMB indicates whether the embedding portion is pruned alongside the weights +3. DLRM model +`./launch_pruning_runs.sh` or `./run_terabyte.sh ${TARGET_SPARSITY} ${PRUNE_TYPE}` +where PRUNE_TYPE is either `magnitude` or `random` +## Xsmm LSTM + +This code may be integrated with Tensorflow to make use of LIBXSMM's LSTM. Support for creating a Python wheel and a pip package can be found in the [directory](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning/tf_lstm_ops) as well. + +## Dispatch + +### Microbenchmark + +This code sample benchmarks the performance of (1) the dispatch mechanism, and (2) the time needed to JIT-generate code for the first time. Both mechanisms are relevant when replacing GEMM calls (see [Call Wrapper](https://libxsmm.readthedocs.io/libxsmm_mm/#call-wrapper) section of the reference documentation), or in any case of calling LIBXSMM's native [GEMM functionality](https://libxsmm.readthedocs.io/libxsmm_mm/). + +**Command Line Interface (CLI)** + +* Optionally takes the number of dispatches/code-generations (default: 10000). +* Optionally takes the number of threads (default: 1). + +**Measurements (Benchmark)** + +* Duration of an empty function call (serves as a reference timing). +* Duration to find an already generated kernel (cached/non-cached). +* Duration to JIT-generate a GEMM kernel. + +In case of a multi-threaded benchmark, the timings represent a highly contended request (worst case). For thread-scaling, it can be observed that read-only accesses (code dispatch) stay roughly with a constant duration whereas write-accesses (code generation) are serialized and hence the duration scales linearly with the number of threads. + +The [Fortran example](https://github.com/hfp/libxsmm/blob/master/samples/utilities/dispatch/dispatch.f) (`dispatch.f`) could use `libxsmm_dmmdispatch` (or similar) like the C code (`dispatch.c`) but intentionally shows the lower-level dispatch interface `libxsmm_xmmdispatch` and also omits using the LIBXSMM module. Not using the module confirms: the same task can be achieved by relying only on FORTRAN 77 language level. + +### User-Data Dispatch + +Further, another [Fortran example](https://github.com/hfp/libxsmm/blob/master/samples/utilities/dispatch/dispatch_udt.f) about [user-data dispatch](https://libxsmm.readthedocs.io/libxsmm_aux/#user-data-dispatch) is not exactly a benchmark. Dispatching user-data containing multiple kernels can obviously save multiple singular dispatches. The C interface for dispatching user-data is designed to follow the same flow as the Fortran interface. + +## MHD Image I/O + +This code sample aims to provide a simple piece of code, which takes an image and produces a visual result using LIBXSMM's MHD image file I/O. Performing a single convolution is *not* a showcase of LIBXSMM's Deeplearning as the code only runs over a single image with one channel. +LIBXSMM's CNNs are vectorized over image channels (multiple images) according to the native vector-width of the processor and otherwise fall back to a high-level implementation. + +**Note**: For high-performance deep learning, please refer to the collection of [CNN layer samples](https://github.com/hfp/libxsmm/tree/master/samples/deeplearning/cnnlayer). + +The executable can run with the following arguments (all arguments are optional): + +> mhd [<filename-in> [<nrepeat> [<kw> [<kh>] [<filename-out>]]]] + +For stable timing (benchmark), the key operation (convolution) may be repeated (`nrepeat`). Further, `kw` and `kh` can specify the kernel-size of the convolution. The `filename-in` and `filename-out` name MHD-files used as input and output respectively. The `filename-in` may be a pseudo-file (that does not exist) but specify the image resolution of generated input (`w`[x`h`] where the file `wxh.mhd` stores the generated image data). To load an image from a familiar format (JPG, PNG, etc.), please have a look at [Meta Image File I/O](https://libxsmm.readthedocs.io/libxsmm_aux/#meta-image-file-io). + +## Scratch Memory Allocation (Microbenchmark) + +This code sample aims to benchmark the performance of the scratch memory allocation. This facility is a viable option to satisfy the need for temporary memory when using the DNN domain of LIBXSMM (small convolutions). Although any kind of readable/writable buffer can be bound to a convolution handle, LIBXSMM's `libxsmm_aligned_scratch` features a thread-safe linear allocator mechanism which can help to lower allocation overhead. + +## Wrapped DGEMM + +This code sample is calling DGEMM and there is no dependency on the LIBXSMM API as it only relies on LAPACK/BLAS interface. Two variants are linked when building the source code: (1) code which is dynamically linked against LAPACK/BLAS, (2) code which is linked using `--wrap=`*symbol* as possible when using a GNU GCC compatible tool chain. For more information, see the [Call Wrapper](https://libxsmm.readthedocs.io/libxsmm_mm/#call-wrapper) section of the reference documentation. + +The same (source-)code will execute in three flavors when running `dgemm-test.sh`: (1) code variant which is dynamically linked against the originally supplied LAPACK/BLAS library, (2) code variant which is linked using the wrapper mechanism of the GNU GCC tool chain, and (3) the first code but using the LD_PRELOAD mechanism (available under Linux). + +**Command Line Interface (CLI)** + +* Optionally takes the number of repeated DGEMM calls +* Shows the performance of the workload (wall time) + diff --git a/third_party/libxsmm/documentation/libxsmm_samples.pdf b/third_party/libxsmm/documentation/libxsmm_samples.pdf new file mode 100644 index 0000000000000000000000000000000000000000..19b257d36121ff946268d46df2f321b671ee346d Binary files /dev/null and b/third_party/libxsmm/documentation/libxsmm_samples.pdf differ diff --git a/third_party/libxsmm/documentation/libxsmm_tune.md b/third_party/libxsmm/documentation/libxsmm_tune.md new file mode 100644 index 0000000000000000000000000000000000000000..05c793e4eb37fbf2e3ec6523a86baaec978bb666 --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_tune.md @@ -0,0 +1,159 @@ +## Customization + +### Intercepted Allocations + +To improve thread-scalability and to avoid frequent memory allocation/deallocation, the [scratch memory allocator](libxsmm_aux.md#memory-allocation) can be leveraged by intercepting existing malloc/free calls. This facility is built into LIBXSMM's main library, but disabled at compile-time (by default); build with `make MALLOC=1` to permanently enable, or build with `make MALLOC=-1` to even require an environment variable `LIBXSMM_MALLOC=1` or an API-call (`libxsmm_set_malloc`). Both runtime settings allow an optional lower and/or an upper bound to select malloc-calls based on the size of the allocation. For the environment option, an extra variable is introduced, e.g., use `LIBXSMM_MALLOC=1 LIBXSMM_MALLOC_LIMIT=4m:1g`. + +```C +void libxsmm_set_malloc(int enabled, const size_t* lo, const size_t* hi); +int libxsmm_get_malloc(size_t* lo, size_t* hi); +``` + +Querying the status may return zero even if there was an attempt to enable this facility (limitation/experimental implementation). Please note, the regular [Scratch Memory API](libxsmm_aux.md#memory-allocation) (e.g., `libxsmm_[get|set]_scratch_limit`) and the related environment variables can apply as well (`LIBXSMM_SCRATCH_LIMIT`, `LIBXSMM_SCRATCH_POOLS`, `LIBXSMM_SCRATCH_SCALE`). If intercepted memory allocations are enabled, the scratch limit is adjusted by default to allow unlimited growth of the scratch domain. Further, an increased verbosity level can help to gain some insight (`LIBXSMM_VERBOSE=3`). + +Intercepting malloc/free is supported by linking LIBXSMM's static or shared main library. The latter of which can be used to intercept calls of an existing and unchanged binary (LD_PRELOAD mechanism). To statically link with LIBXSMM and to intercept existing malloc/free calls, the following changes to the application's link stage are recommended: + +```bash +gcc [...] -Wl,--export-dynamic \ + -Wl,--wrap=malloc,--wrap=calloc,--wrap=realloc \ + -Wl,--wrap=memalign,--wrap=free \ + /path/to/libxsmm.a +``` + +The main library causes a BLAS-dependency which may be already fulfilled for the application in question. However, if this is not the case (unresolved symbols), `libxsmmnoblas.a` must be linked in addition. Depending on the dependencies of the application, the link order may also need to be adjusted. Other i.e. a GNU-compatible compiler (as shown above), can induce additional requirements (compiler runtime libraries). + +**Note**: The Intel Compiler may need "libirc", i.e., `-lirc` in front of `libxsmm.a`. Linking LIBXSMM's static library may require above mentioned linker flags (`--wrap`) in particular when using Intel Fortran (IFORT) as a linker driver unless `CALL libxsmm_init()` is issued (or at least one symbol of LIBXSMM's main library is referenced; check with `nm application | grep libxsmm`). Linking the static library by using the GNU compiler does not strictly need special flags when linking the application. + +Linking the shared library form of LIBXSMM (`make STATIC=0`) has similar requirements with respect to the application but does not require `-Wl,--wrap` although `-Wl,--export-dynamic` is necessary if the application is statically linked (beside of LIBXSMM linked in a shared fashion). The LD_PRELOAD based mechanism does not need any changes to the link step of an application. However, `libxsmmnoblas` may be required if the application does not already link against BLAS. + +```bash +LD_PRELOAD="libxsmm.so libxsmmnoblas.so" +LD_LIBRARY_PATH=/path/to/libxsmm/lib:${LD_LIBRARY_PATH} +LIBXSMM_MALLOC=1 +``` + +**Note**: If the application already uses BLAS, of course `libxsmmnoblas` must not be used! + +The following code can be compiled and linked with `gfortran example.f -o example`: + +```fortran + PROGRAM allocate_test + DOUBLE PRECISION, ALLOCATABLE :: a(:), b(:), c(:) + INTEGER :: i, repeat = 100000 + DOUBLE PRECISION :: t0, t1, d + + ALLOCATE(b(16*1024)) + ALLOCATE(c(16*1024)) + CALL CPU_TIME(t0) + DO i = 1, repeat + ALLOCATE(a(16*1024*1024)) + DEALLOCATE(a) + END DO + CALL CPU_TIME(t1) + DEALLOCATE(b) + DEALLOCATE(c) + d = t1 - t0 + + WRITE(*, "(A,F10.1,A)") "duration:", (1D3 * d), " ms" + END PROGRAM +``` + +Running with `LIBXSMM_VERBOSE=3 LIBXSMM_MALLOC=1 LD_PRELOAD=... LD_LIBRARY_PATH=... ./example` displays: `Scratch: 132 MB (mallocs=1, pools=1)` which shows the innermost allocation/deallocation was served by the scratch memory allocator. + +### Static Specialization + +By default, LIBXSMM uses the [JIT backend](index.md#jit-backend) which is automatically building optimized code (JIT=1). Matrix multiplication kernels can be also statically specialized at compile-time of the library (M, N, and K values). This mechanism also extends the interface of the library because function prototypes are included into both the C and FORTRAN interface. + +```bash +make M="2 4" N="1" K="$(echo $(seq 2 5))" +``` + +The above example is generating the following set of (M,N,K) triplets: + +```bash +(2,1,2), (2,1,3), (2,1,4), (2,1,5), +(4,1,2), (4,1,3), (4,1,4), (4,1,5) +``` + +The index sets are in a loop-nest relationship (M(N(K))) when generating the indexes. Moreover, an empty index set resolves to the next non-empty outer index set of the loop nest (including to wrap around from the M to K set). An empty index set does not participate in the loop-nest relationship. Here is an example of generating multiplication routines which are "squares" with respect to M and N (N inherits the current value of the "M loop"): + +```bash +make M="$(echo $(seq 2 5))" K="$(echo $(seq 2 5))" +``` + +An even more flexible specialization is possible by using the MNK variable when building the library. It takes a list of indexes which are eventually grouped (using commas): + +```bash +make MNK="2 3, 23" +``` + +Each group of the above indexes is combined into all possible triplets generating the following set of (M,N,K) values: + +```bash +(2,2,2), (2,2,3), (2,3,2), (2,3,3), +(3,2,2), (3,2,3), (3,3,2), (3,3,3), (23,23,23) +``` + +Of course, both mechanisms (M/N/K and MNK based) can be combined by using the same command line (make). Static optimization and JIT can also be combined (no need to turn off the JIT backend). + +### User-Data Dispatch + +It can be desired to dispatch user-defined data, i.e., to query a value based on a key. This functionality can be used to, e.g., dispatch multiple kernels in one step if a code location relies on multiple kernels. This way, one can pay the cost of dispatch one time per task rather than according to the number of JIT-kernels used by this task. This functionality is detailed in the section about [Service Functions](libxsmm_aux.md#user-data-dispatch). + +### Targeted Compilation + +Specifying a code path is not necessary if the JIT backend is not disabled. However, disabling JIT compilation, statically generating a collection of kernels, and targeting a specific instruction set extension for the entire library looks like: + +```bash +make JIT=0 AVX=3 MNK="1 2 3 4 5" +``` + +The above example builds a library which cannot be deployed to anything else but the Intel Knights Landing processor family ("KNL") or future Intel Xeon processors supporting foundational Intel AVX‑512 instructions (AVX‑512F). The latter might be even more adjusted by supplying MIC=1 (along with AVX=3), however this does not matter since critical code is in inline assembly (and not affected). Similarly, SSE=0 (or JIT=0 without SSE or AVX build flag) employs an "arch-native" approach whereas AVX=1, AVX=2 (with FMA), and AVX=3 are specifically selecting the kind of Intel AVX code. Moreover, controlling the target flags manually or adjusting the code optimizations is also possible. The following example is GCC-specific and corresponds to OPT=3, AVX=3, and MIC=1: + +```bash +make OPT=3 TARGET="-mavx512f -mavx512cd -mavx512er -mavx512pf" +``` + +An extended interface can be generated which allows to perform software prefetches. Prefetching data might be helpful when processing batches of matrix multiplications where the next operands are farther away or otherwise unpredictable in their memory location. The prefetch strategy can be specified similar as shown in the section [Generator Driver](libxsmm_be.md#generator-driver), i.e., by either using the number of the shown enumeration, or by exactly using the name of the prefetch strategy. The only exception is PREFETCH=1 which is automatically selecting a strategy per an internal table (navigated by CPUID flags). The following example is requesting the "AL2jpst" strategy: + +```bash +make PREFETCH=8 +``` + +The prefetch interface is extending the signature of all kernels by three arguments (pa, pb, and pc). These additional arguments are specifying the locations of the operands of the next multiplication (the next a, b, and c matrices). Providing unnecessary arguments in case of the three-argument kernels is not big a problem (beside of some additional call-overhead), however running a 3-argument kernel with more than three arguments and thereby picking up garbage data is misleading or disabling the hardware prefetcher (due to software prefetches). In this case, a misleading prefetch location is given plus an eventual page fault due to an out-of-bounds (garbage-)location. + +Further, a generated configuration ([template](https://github.com/hfp/libxsmm/blob/master/include/libxsmm_config.h)) of the library encodes the parameters for which the library was built for (static information). This helps optimizing client code related to the library's functionality. For example, the LIBXSMM_MAX_\* and LIBXSMM_AVG_\* information can be used with the LIBXSMM_PRAGMA_LOOP_COUNT macro to hint loop trip counts when handling matrices related to the problem domain of LIBXSMM. + +### Auto-dispatch + +The function `libxsmm_?mmdispatch` helps amortizing the cost of the dispatch when multiple calls with the same M, N, and K are needed. The automatic code dispatch is orchestrating two levels: + +1. Specialized routine (implemented in assembly code), +2. BLAS library call (fallback). + +Both levels are accessible directly, which allows to customize the code dispatch. The fallback level may be supplied by the Intel Math Kernel Library (Intel MKL) 11.2 DIRECT CALL feature. + +Further, a preprocessor symbol denotes the largest problem-size (*M* x *N* x *K*) that belongs to the first level, and therefore determines if a matrix multiplication falls back to BLAS. The problem-size threshold can be configured by using for example: + +```bash +make THRESHOLD=$((60 * 60 * 60)) +``` + +The maximum of the given threshold and the largest requested specialization refines the value of the threshold. Please note that explicitly JIT'ting and executing a kernel is possible and independent of the threshold. If a problem-size is below the threshold, dispatching the code requires to figure out whether a specialized routine exists or not. + +For statically generated code, the precision can be selected: + +```bash +make PRECISION=2 +``` + +The default preference is to generate and register both single and double-precision code (PRECISION=0). Specifying PRECISION=1|2 is generating and registering single-precision or double-precision code respectively. + +The automatic dispatch is highly convenient because existing GEMM calls can serve specialized kernels (even in a binary compatible fashion), however there is (and always will be) an overhead associated with looking up the code-registry and checking whether the code determined by the GEMM call is already JIT'ted or not. This lookup has been optimized with various techniques such as specialized CPU instructions to calculate CRC32 checksums, to avoid costly synchronization (needed for thread-safety) until it is ultimately known that the requested kernel is not yet JIT'ted, and by implementing a small thread-local cache of recently dispatched kernels. The latter of which can be adjusted in size (only power-of-two sizes) but also disabled: + +```bash +make CACHE=0 +``` + +Please note that measuring the relative cost of automatically dispatching a requested kernel depends on the kernel size (obviously smaller matrices are multiplied faster on an absolute basis), however smaller matrix multiplications are bottlenecked by memory bandwidth rather than arithmetic intensity. The latter implies the highest relative overhead when (artificially) benchmarking the very same multiplication out of the CPU-cache. + diff --git a/third_party/libxsmm/documentation/libxsmm_valid.md b/third_party/libxsmm/documentation/libxsmm_valid.md new file mode 100644 index 0000000000000000000000000000000000000000..fda4ff307dcb887094c67366ff16be03647dec4d --- /dev/null +++ b/third_party/libxsmm/documentation/libxsmm_valid.md @@ -0,0 +1,97 @@ +## Basic Tests + +To run basic [tests](http://libxsmm.readthedocs.io/#classic-library-abi): + +```bash +make tests +``` + +Remember: a set of key-value pairs represents a single unique (re-)build (and test): + +```bash +make STATIC=0 tests +``` + +There is a whole collection of test targets available (`test-cp2k`, `test-cpp`, `test-nek`). However, it is then better to rely on test-suites. + +## Test Suites + +It is possible to run tests like LIBXSMM's continuous integration ([https://travis-ci.org/hfp/libxsmm](https://travis-ci.org/hfp/libxsmm)): + +```bash +scripts/tool_test.sh +``` + +The above command runs the entire collection ("scripts/tool_test.sh 0"). However, one test (of currently 11 tests) can be selected by number (1-11): + +```bash +scripts/tool_test.sh 1 +``` + +The suite itself can be also selected. For example, some DNN tests are described in `.test-dnn.yml`: + +```bash +TESTSET=test-dnn scripts/tool_test.sh +``` + +In general, all key-value pairs valid for LIBXSMM's `make` can be given as part of the environment: + +```bash +AVX=3 MIC=0 TESTSET=test-dnn scripts/tool_test.sh +``` + +Please note, the suite/test itself may be comprised of key-value pairs that take precedence. + +## CI Tests + +The `tool_test.sh` script is included in repository archives and releases i.e., it works for non-repository folders. In contrast, the Continuous Integration (CI) use case relies on the Git command being present and the folder being a Git-clone. + +Functionality + +* `[skip ci]` as part of a commit message will not trigger the CI agents, and tests are skipped for such a commit. +* `[full ci]` as part of a commit message will trigger a full test even if the setup uses the "Fast CI" option. + +The "Fast CI" option is enabled per filename given as 2nd command line argument: + +```bash +scripts/tool_test.sh 1 .fullci +``` + +In the above example, a file named `.fullci` may contain path/file patterns (wildcard format) triggering a full test if the files changed by the commit match any of the patterns. + +## Portability + +It is desirable to exercise portability and reliability of LIBXSMM's source code even on Non-Intel Architecture by the means of compilation, linkage, and generic tests. This section is *not* about Intel Architecture (or compatible). Successful compilation (or even running some of the tests successfully) does not mean LIBXSMM is valuable on that platform. + +Make sure to rely on `PLATFORM=1`, otherwise a compilation error should occur _Intel Architecture or compatible CPU required!_ This error avoids (automated) attempts to upstream LIBXSMM to an unsupported platform. LIBXSMM is upstreamed for Intel Architecture on all major Linux distributions, FreeBSD, and others. If compilation fails with _LIBXSMM is only supported on a 64-bit platform!_, `make PLATFORM=1 DBG=1` can be used to exercise compilation. + +If platform support is forced (`PLATFORM=1`), runtime code generation is disabled at compile-time (`JIT=0`). Runtime code generation can be also enabled (`PLATFORM=1 JIT=1`) but code-dispatch will still return NULL-kernels. However, some tests will start failing as missing JIT-support it is not signaled at compile-time as with `JIT=0`. + +**Note**: JIT-support normally guarantees a non-NULL code pointer ("kernel") if the request is according to the [limitations](https://github.com/hfp/libxsmm/wiki/Q&A#what-is-a-small-matrix-multiplication) (user-code is not asked to check for a NULL-kernel), which does not hold true if JIT is enabled on a platform that does not implement it. + +### TinyCC + +The Tiny C Compiler (TinyCC) supports Intel Architecture, but lacks at least support for thread-local storage (TLS). + +```bash +make CC=tcc THREADS=0 INTRINSICS=0 VLA=0 ASNEEDED=0 BLAS=0 FORCE_CXX=0 +``` + +### IBM XL Compiler for Linux (POWER) + +The POWER platform requires aforementioned `PLATFORM=1` to unlock compilation. + +```bash +make PLATFORM=1 CC=xlc CXX=xlc++ FC=xlf +``` + +### Cross-compilation for ARM + +ARM AArch64 is regularly [supported](https://github.com/hfp/libxsmm/wiki/Compatibility#arm-aarch64). However, 32-bit ARM requires aforementioned `PLATFORM=1` to unlock compilation (similar to 32-bit Intel Architecture). Unlocking compilation for 32-bit ARM is not be confused with supporting 32-bit ARM architectures. + +```bash +make PLATFORM=1 AR=arm-linux-gnueabi-ar \ + FC=arm-linux-gnueabi-gfortran \ + CXX=arm-linux-gnueabi-g++ \ + CC=arm-linux-gnueabi-gcc +``` diff --git a/third_party/libxsmm/ide/_vs2019-configure.bat b/third_party/libxsmm/ide/_vs2019-configure.bat new file mode 100644 index 0000000000000000000000000000000000000000..8ae4e7ca6d1fb52c90db2e0320e180c8ff1cef9b --- /dev/null +++ b/third_party/libxsmm/ide/_vs2019-configure.bat @@ -0,0 +1,17 @@ +@ECHO OFF +SETLOCAL + +ECHO ================================================================================ +ECHO One-time configuration (Cygwin w/ GNU GCC, GNU Make, and Python needed in PATH) +ECHO When configured, it is sufficient to start _vs2019.bat or _vs2019.sln +ECHO IMPORTANT: due to zero-config, configuration is not necessary anymore! +ECHO One may terminate this configuration (CTRL-C) +ECHO and simply start _vs2019.bat or _vs2019.sln. +PAUSE +cd .. +bash -c "make realclean ; make headers sources" +cd ide + +CALL %~d0"%~p0"_vs2019.bat + +ENDLOCAL \ No newline at end of file diff --git a/third_party/libxsmm/ide/libxsmm_generator_gemm_driver.vcxproj b/third_party/libxsmm/ide/libxsmm_generator_gemm_driver.vcxproj new file mode 100644 index 0000000000000000000000000000000000000000..32b535b30c4bff7ea4d3303da643b494874ad562 --- /dev/null +++ b/third_party/libxsmm/ide/libxsmm_generator_gemm_driver.vcxproj @@ -0,0 +1,395 @@ + + + + + debug + Win32 + + + debug + x64 + + + symbols + Win32 + + + symbols + x64 + + + release + Win32 + + + release + x64 + + + + + + + + + + + + libxsmm_generator_gemm_driver + {47EDE325-4516-48DA-862B-F689F12DDBD3} + 10.0 + + + + Application + Disabled + Disabled + v142 + + true + + + Application + true + true + Disabled + Disabled + v142 + + + + Application + true + true + Disabled + Disabled + v142 + + true + + + Application + Disabled + Disabled + v142 + + true + + + true + Application + true + Disabled + Disabled + v142 + + + + true + Application + true + Disabled + Disabled + true + v142 + + + + + + + + + + + + + + + + + + + + + + + + + + <_ProjectFileVersion>10.0.30319.1 + bin\ia32\ + bin\ia32\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + bin\intel64\ + bin\intel64\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + bin\ia32\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + bin\intel64\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + + + $(ProjectName)-$(Configuration) + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + + + $(ProjectName)-$(Configuration) + obj\$(Platform)-$(Configuration)\$(ProjectName)\ + + + + + $(ProjectName)-$(Configuration) + + + $(ProjectName)-$(Configuration) + + + + Full + $(SolutionDir)..\include;$(LIBXSMMROOT)\include;%(AdditionalIncludeDirectories) + __SUPPRESS_FOR_PRODUCT;_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES;_CRT_SECURE_NO_DEPRECATE;_SCL_SECURE_NO_DEPRECATE;_USE_MATH_DEFINES;WIN32_LEAN_AND_MEAN;NOMINMAX;NDEBUG;%(PreprocessorDefinitions) + true + MultiThreadedDLL + false + Level4 + Fast + NoTraps + true + true + StreamingSIMDExtensions2 + None + false + true + 3948,10373,10382 + HOST + true + + + 0x0407 + + + $(OutDir)$(TargetName)$(TargetExt) + true + true + true + Console + $(SolutionDir)..\lib\ia32;$(LIBXSMMROOT)\lib\ia32;%(AdditionalLibraryDirectories) + libxsmm.lib;libxsmmnoblas.lib;%(AdditionalDependencies) + + + + + Console + + + + + + MaxSpeed + $(SolutionDir)..\include;$(LIBXSMMROOT)\include;%(AdditionalIncludeDirectories) + __SUPPRESS_FOR_PRODUCT;_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES;_CRT_SECURE_NO_DEPRECATE;_SCL_SECURE_NO_DEPRECATE;_USE_MATH_DEFINES;WIN32_LEAN_AND_MEAN;NOMINMAX;NDEBUG;%(PreprocessorDefinitions) + true + MultiThreadedDLL + false + Level4 + Fast + NoTraps + true + true + StreamingSIMDExtensions2 + None + false + true + SingleFile + 3948,10373,10382 + HOST + true + + + 0x0407 + + + $(OutDir)$(TargetName)$(TargetExt) + true + true + true + Console + $(SolutionDir)..\lib\ia32;$(LIBXSMMROOT)\lib\ia32;%(AdditionalLibraryDirectories) + libxsmm-$(Configuration).lib;libxsmmnoblas-$(Configuration).lib;%(AdditionalDependencies) + true + + + + + Console + + + + + + X64 + + + Full + $(SolutionDir)..\include;$(LIBXSMMROOT)\include;%(AdditionalIncludeDirectories) + __SUPPRESS_FOR_PRODUCT;_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES;_CRT_SECURE_NO_DEPRECATE;_SCL_SECURE_NO_DEPRECATE;_USE_MATH_DEFINES;WIN32_LEAN_AND_MEAN;NOMINMAX;NDEBUG;%(PreprocessorDefinitions) + true + MultiThreadedDLL + false + Level4 + Fast + NoTraps + true + true + None + false + true + 3948,10373,10382 + HOST + true + + + 0x0407 + + + $(OutDir)$(TargetName)$(TargetExt) + true + true + Console + $(SolutionDir)..\lib\intel64;$(LIBXSMMROOT)\lib\intel64;%(AdditionalLibraryDirectories) + libxsmm.lib;libxsmmnoblas.lib;%(AdditionalDependencies) + + + + + Console + + + + + + X64 + + + MaxSpeed + $(SolutionDir)..\include;$(LIBXSMMROOT)\include;%(AdditionalIncludeDirectories) + __SUPPRESS_FOR_PRODUCT;_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES;_CRT_SECURE_NO_DEPRECATE;_SCL_SECURE_NO_DEPRECATE;_USE_MATH_DEFINES;WIN32_LEAN_AND_MEAN;NOMINMAX;NDEBUG;%(PreprocessorDefinitions) + true + MultiThreadedDLL + false + Level4 + Fast + NoTraps + true + true + None + false + true + SingleFile + 3948,10373,10382 + HOST + true + + + 0x0407 + + + $(OutDir)$(TargetName)$(TargetExt) + true + true + Console + $(SolutionDir)..\lib\intel64;$(LIBXSMMROOT)\lib\intel64;%(AdditionalLibraryDirectories) + libxsmm-$(Configuration).lib;libxsmmnoblas-$(Configuration).lib;%(AdditionalDependencies) + true + + + + + Console + + + + + + Disabled + $(SolutionDir)..\include;$(LIBXSMMROOT)\include;%(AdditionalIncludeDirectories) + __SUPPRESS_FOR_PRODUCT;_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES;_CRT_SECURE_NO_DEPRECATE;_SCL_SECURE_NO_DEPRECATE;_USE_MATH_DEFINES;WIN32_LEAN_AND_MEAN;NOMINMAX;_DEBUG;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Level4 + ProgramDatabase + None + false + true + 3948,10373,10382 + HOST + true + + + 0x0407 + + + $(OutDir)$(TargetName)$(TargetExt) + true + true + true + Console + $(SolutionDir)..\lib\ia32;$(LIBXSMMROOT)\lib\ia32;%(AdditionalLibraryDirectories) + libxsmm-$(Configuration).lib;libxsmmnoblas-$(Configuration).lib;%(AdditionalDependencies) + MSVCRT + + + + + Console + + + + + + X64 + + + Disabled + $(SolutionDir)..\include;$(LIBXSMMROOT)\include;%(AdditionalIncludeDirectories) + __SUPPRESS_FOR_PRODUCT;_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES;_CRT_SECURE_NO_DEPRECATE;_SCL_SECURE_NO_DEPRECATE;_USE_MATH_DEFINES;WIN32_LEAN_AND_MEAN;NOMINMAX;_DEBUG;%(PreprocessorDefinitions) + MultiThreadedDebugDLL + Level4 + ProgramDatabase + None + false + true + 3948,10373,10382 + HOST + true + + + 0x0407 + + + $(OutDir)$(TargetName)$(TargetExt) + true + true + Console + $(SolutionDir)..\lib\intel64;$(LIBXSMMROOT)\lib\intel64;%(AdditionalLibraryDirectories) + libxsmm-$(Configuration).lib;libxsmmnoblas-$(Configuration).lib;%(AdditionalDependencies) + MSVCRT + + + + + Console + + + + + + + \ No newline at end of file diff --git a/third_party/libxsmm/include/.make b/third_party/libxsmm/include/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/include/libxsmm.f b/third_party/libxsmm/include/libxsmm.f new file mode 100644 index 0000000000000000000000000000000000000000..828e32059f6393064a944011989d3f005ceee559 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm.f @@ -0,0 +1,2087 @@ +!=======================================================================! +! Copyright (c) Intel Corporation - All rights reserved. ! +! This file is part of the LIBXSMM library. ! +! ! +! For information on the license, see the LICENSE file. ! +! Further information: https://github.com/hfp/libxsmm/ ! +! SPDX-License-Identifier: BSD-3-Clause ! +!=======================================================================! +! Hans Pabst (Intel Corp.) +!=======================================================================! + + MODULE LIBXSMM + USE, INTRINSIC :: ISO_C_BINDING, ONLY: & + & C_DOUBLE, C_FLOAT, C_DOUBLE_COMPLEX, C_FLOAT_COMPLEX, & + & C_LONG_LONG, C_INT, C_SHORT, C_CHAR, C_INT8_T, C_BOOL, & + & C_F_POINTER, C_ASSOCIATED, C_LOC, C_PTR, & + & C_FUNPTR, C_NULL_FUNPTR, C_NULL_PTR + IMPLICIT NONE + + !> Name of the version (stringized set of version numbers). + CHARACTER(*), PARAMETER :: LIBXSMM_VERSION = "1.16.1-1534" + !> Name of the branch of which the version is derived from. + CHARACTER(*), PARAMETER :: LIBXSMM_BRANCH = "master" + !> Major version based on the last reachable tag under RCS. + INTEGER(C_INT), PARAMETER :: LIBXSMM_VERSION_MAJOR = 1 + !> Minor version based on the last reachable tag of the RCS. + INTEGER(C_INT), PARAMETER :: LIBXSMM_VERSION_MINOR = 16 + !> Update number based on the last reachable tag under RCS. + INTEGER(C_INT), PARAMETER :: LIBXSMM_VERSION_UPDATE = 1 + !> Patch number counting commits since the last version stamp. + INTEGER(C_INT), PARAMETER :: LIBXSMM_VERSION_PATCH = 1534 + + !> Parameters the library and static kernels were built for. + INTEGER(C_INT), PARAMETER :: LIBXSMM_CACHELINE = 64 + INTEGER(C_INT), PARAMETER :: LIBXSMM_ALIGNMENT = 64 + INTEGER(C_INT), PARAMETER :: LIBXSMM_PREFETCH = -1 + INTEGER(C_INT), PARAMETER :: LIBXSMM_MAX_MNK = 262144 + INTEGER(C_INT), PARAMETER :: LIBXSMM_MAX_DIM = 64 + INTEGER(C_INT), PARAMETER :: LIBXSMM_FLAGS = 0 + INTEGER(C_INT), PARAMETER :: LIBXSMM_ILP64 = 0 + + !> Parameters supplied for backward compatibility (deprecated). + INTEGER(C_INT), PARAMETER :: LIBXSMM_COL_MAJOR = 1 + INTEGER(C_INT), PARAMETER :: LIBXSMM_ROW_MAJOR = 0 + + !> LIBXSMM_BLASINT_KIND impacts BLAS interface (LP64: 32-bit, ILP64: 64-bit). + INTEGER(C_INT), PARAMETER :: LIBXSMM_BLASINT_KIND = C_INT + !> Integer kind used by timer interface. + INTEGER(C_INT), PARAMETER :: LIBXSMM_TICKINT_KIND = C_LONG_LONG + + !> Parameters representing the GEMM performed by the simplified interface. + REAL(C_DOUBLE), PARAMETER :: LIBXSMM_ALPHA = REAL(1, C_DOUBLE) + REAL(C_DOUBLE), PARAMETER :: LIBXSMM_BETA = REAL(1, C_DOUBLE) + + !> Flag enumeration which can be IORed. + INTEGER(C_INT), PARAMETER :: & + & LIBXSMM_GEMM_FLAG_NONE = 0, & + & LIBXSMM_GEMM_FLAG_TRANS_A = 1, & + & LIBXSMM_GEMM_FLAG_TRANS_B = 2, & + & LIBXSMM_GEMM_FLAG_TRANS_AB = IOR( & + & LIBXSMM_GEMM_FLAG_TRANS_A, LIBXSMM_GEMM_FLAG_TRANS_B), & + & LIBXSMM_GEMM_FLAG_BETA_0 = 16, & + & LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT = 2176, & + & LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0 = IOR( & + & LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, & + & LIBXSMM_GEMM_FLAG_BETA_0) + + !> Flag enumeration which can be IORed. + INTEGER(C_INT), PARAMETER :: & + ! Handle recorded batch unsynchronized-parallel. + & LIBXSMM_MMBATCH_FLAG_DEFAULT = 0, & + ! Synchronize among C matrices. + & LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED = 512, & + ! Handle recorded batch sequentially. + & LIBXSMM_MMBATCH_FLAG_SEQUENTIAL = 1024, & + ! Only record a statistic of potential SMMs. + & LIBXSMM_MMBATCH_FLAG_STATISTIC = 2048 + + !> Enumerates element/data types. + INTEGER(C_INT), PARAMETER :: & + & LIBXSMM_DATATYPE_F64 = 0, & + & LIBXSMM_DATATYPE_F32 = 1, & + & LIBXSMM_DATATYPE_BF16 = 2, & + & LIBXSMM_DATATYPE_I64 = 3, & + & LIBXSMM_DATATYPE_I32 = 4, & + & LIBXSMM_DATATYPE_I16 = 5, & + & LIBXSMM_DATATYPE_I8 = 6, & + & LIBXSMM_DATATYPE_UNSUPPORTED = 7 + + !> Denotes the precision/data type of GEMM (for weak-typed + !> interface functions such as libxsmm_xmmdispatch). + INTEGER(C_INT), PARAMETER :: & + & LIBXSMM_GEMM_PRECISION_F64 = LIBXSMM_DATATYPE_F64, & + & LIBXSMM_GEMM_PRECISION_F32 = LIBXSMM_DATATYPE_F32, & + & LIBXSMM_GEMM_PRECISION_BF16 = LIBXSMM_DATATYPE_BF16, & + & LIBXSMM_GEMM_PRECISION_I32 = LIBXSMM_DATATYPE_I32, & + & LIBXSMM_GEMM_PRECISION_I16 = LIBXSMM_DATATYPE_I16, & + & LIBXSMM_GEMM_PRECISION_I8 = LIBXSMM_DATATYPE_I8 + + !> Enumeration of the available prefetch strategies which can be IORed. + INTEGER(C_INT), PARAMETER :: & + ! Automatically select strategy (frontend). + & LIBXSMM_PREFETCH_AUTO = -1, & + ! No prefetching and no prefetch function signature. + & LIBXSMM_PREFETCH_NONE = 0, & + ! Only function prefetch signature. + & LIBXSMM_PREFETCH_SIGONLY = 1, & + ! Prefetch PA using accesses to A. + & LIBXSMM_GEMM_PREFETCH_AL2 = 2, & + ! Prefetch PB using accesses to C. + & LIBXSMM_GEMM_PREFETCH_BL2_VIA_C = 4, & + ! Prefetch A ahead. + & LIBXSMM_GEMM_PREFETCH_AL2_AHEAD = 8, & + ! Composed prefetch strategies. + & LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C = IOR( & + & LIBXSMM_GEMM_PREFETCH_BL2_VIA_C, & + & LIBXSMM_GEMM_PREFETCH_AL2), & + & LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD = IOR( & + & LIBXSMM_GEMM_PREFETCH_BL2_VIA_C, & + & LIBXSMM_GEMM_PREFETCH_AL2_AHEAD), & + ! Current B into L1. + & LIBXSMM_GEMM_PREFETCH_BL1 = 16 + + !> Enumerates the available target architectures and instruction + !> set extensions as returned by libxsmm_get_target_archid(). + INTEGER(C_INT), PARAMETER :: & + & LIBXSMM_TARGET_ARCH_UNKNOWN = 0, & + & LIBXSMM_TARGET_ARCH_GENERIC = 1, & + & LIBXSMM_X86_GENERIC = 1002, & + & LIBXSMM_X86_SSE3 = 1003, & + & LIBXSMM_X86_SSE4 = 1004, & + & LIBXSMM_X86_AVX = 1005, & + & LIBXSMM_X86_AVX2 = 1006, & + & LIBXSMM_X86_AVX512 = 1007, & + & LIBXSMM_X86_AVX512_MIC = 1010, & + & LIBXSMM_X86_AVX512_KNM = 1011, & + & LIBXSMM_X86_AVX512_CORE = 1020, & + & LIBXSMM_X86_AVX512_CLX = 1021, & + & LIBXSMM_X86_AVX512_CPX = 1022 + + !> Generic function type (double-precision). + TYPE, BIND(C) :: LIBXSMM_DMMFUNCTION + TYPE(C_FUNPTR) :: handle = C_NULL_FUNPTR + END TYPE + + !> Generic function type (single-precision). + TYPE, BIND(C) :: LIBXSMM_SMMFUNCTION + TYPE(C_FUNPTR) :: handle = C_NULL_FUNPTR + END TYPE + + !> Generic function type (low-precision) + TYPE, BIND(C) :: LIBXSMM_WIMMFUNCTION + TYPE(C_FUNPTR) :: handle = C_NULL_FUNPTR + END TYPE + + !> Generic function types with certain arity. + ABSTRACT INTERFACE + PURE SUBROUTINE LIBXSMM_FUNCTION3(a, b, c) BIND(C) + IMPORT :: C_PTR + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c + END SUBROUTINE + + PURE SUBROUTINE LIBXSMM_FUNCTION6(a, b, c, pa, pb, pc) BIND(C) + IMPORT :: C_PTR + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c + TYPE(C_PTR), INTENT(IN), VALUE :: pa, pb, pc + END SUBROUTINE + END INTERFACE + + !> Structure of differences with matrix norms according + !> to http://www.netlib.org/lapack/lug/node75.html). + TYPE, BIND(C) :: LIBXSMM_MATDIFF_INFO + REAL(C_DOUBLE) norm1_abs, norm1_rel !! One-norm + REAL(C_DOUBLE) normi_abs, normi_rel !! Infinity-norm + REAL(C_DOUBLE) normf_rel !! Froebenius-norm + !> Maximum difference, L2-norm (absolute and relative), and R-squared. + REAL(C_DOUBLE) linf_abs, linf_rel, l2_abs, l2_rel, rsq + !> Statistics: sum/l1, min., max., arith. avg., and variance. + REAL(C_DOUBLE) l1_ref, min_ref, max_ref, avg_ref, var_ref + !> Statistics: sum/l1, min., max., arith. avg., and variance. + REAL(C_DOUBLE) l1_tst, min_tst, max_tst, avg_tst, var_tst + !> Values (v_ref, v_tst) and location (m, n) of largest linf_abs. + REAL(C_DOUBLE) v_ref, v_tst + !> Location (m, n) of largest difference (linf_abs). + INTEGER(LIBXSMM_BLASINT_KIND) m, n + END TYPE + + INTERFACE + !> Initialize the library; pay for setup cost at a specific point. + SUBROUTINE libxsmm_init() BIND(C) + END SUBROUTINE + + !> De-initialize the library and free internal memory (optional). + SUBROUTINE libxsmm_finalize() BIND(C) + END SUBROUTINE + + !> Get the default prefetch strategy. + PURE FUNCTION libxsmm_get_gemm_auto_prefetch() BIND(C) + IMPORT :: C_INT + INTEGER(C_INT) :: libxsmm_get_gemm_auto_prefetch + END FUNCTION + + !> Set the default prefetch strategy. + SUBROUTINE libxsmm_set_gemm_auto_prefetch(strategy) BIND(C) + IMPORT :: C_INT + INTEGER(C_INT), INTENT(IN), VALUE :: strategy + END SUBROUTINE + + !> Returns the architecture and instruction set extension as determined + !> by the CPUID flags, as set by the libxsmm_get_target_arch* functions, + !> or as set by the LIBXSMM_TARGET environment variable. + PURE FUNCTION libxsmm_get_target_archid() BIND(C) + IMPORT :: C_INT + INTEGER(C_INT) :: libxsmm_get_target_archid + END FUNCTION + + !> Set target architecture (archid: see PARAMETER enumeration) + !> for subsequent code generation (JIT). + SUBROUTINE libxsmm_set_target_archid(archid) BIND(C) + IMPORT :: C_INT + INTEGER(C_INT), INTENT(IN), VALUE :: archid + END SUBROUTINE + + !> Set target architecture for subsequent code generation (JIT). + !> arch="0"|"sse"|"snb"|"hsw"|"knl"|"knm"|"skx"|"clx"|"cpx", + !> or "0" to rely on the CPUID (default). + !> There are some alternative target names as well: + !> "sse", "avx", "avx2", "avx3" (incomplete list). + SUBROUTINE libxsmm_set_target_arch(arch) BIND(C) + IMPORT :: C_CHAR + CHARACTER(C_CHAR), INTENT(IN) :: arch(*) + END SUBROUTINE + + !> Get the level of verbosity. + PURE FUNCTION libxsmm_get_verbosity() BIND(C) + IMPORT :: C_INT + INTEGER(C_INT) :: libxsmm_get_verbosity + END FUNCTION + + !> Set the level of verbosity (0: off, positive value: verbosity level, + !> negative value: maximum verbosity, which also dumps JIT-code). + SUBROUTINE libxsmm_set_verbosity(level) BIND(C) + IMPORT :: C_INT + INTEGER(C_INT), INTENT(IN), VALUE :: level + END SUBROUTINE + + !> Impure function which returns the current clock tick of a + !> monotonic timer source; uses a platform-specific resolution. + !> Implicit FORTRAN 77 interface: not available. + INTEGER(LIBXSMM_TICKINT_KIND) & + & FUNCTION libxsmm_timer_tick() BIND(C) + IMPORT :: LIBXSMM_TICKINT_KIND + END FUNCTION + + !> Impure function (timer freq. may vary) which returns the duration + !> (in seconds) between two values received by libxsmm_timer_tick. + !> Implicit FORTRAN 77 interface: not available. + FUNCTION libxsmm_timer_duration(tick0, tick1) BIND(C) + IMPORT :: LIBXSMM_TICKINT_KIND, C_DOUBLE + INTEGER(LIBXSMM_TICKINT_KIND), INTENT(IN), VALUE :: tick0 + INTEGER(LIBXSMM_TICKINT_KIND), INTENT(IN), VALUE :: tick1 + REAL(C_DOUBLE) :: libxsmm_timer_duration + END FUNCTION + + !> Deallocates the JIT'ted code, or unregisters + !> and releases code from the registry. + !> Implicit FORTRAN 77 interface: + !> INTEGER(8) :: kernel + SUBROUTINE libxsmm_release_kernel(kernel) & + & BIND(C, NAME="libxsmm_release_kernel_") + IMPORT :: C_FUNPTR + TYPE(C_FUNPTR), INTENT(IN) :: kernel + END SUBROUTINE + + !> Type-generic (unsafe) code dispatch (trylock: impure routine). + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: gemm_precision, flags, prefetch + !> INTEGER(4|8) :: m, n, k, lda, ldb, ldc + !> REAL(4|8) :: alpha, beta + !> INTEGER(8) :: kernel + SUBROUTINE libxsmm_xmmdispatch(kernel, gemm_precision, & + & m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch) & + & BIND(C, NAME="libxsmm_xmmdispatch_") + IMPORT :: C_FUNPTR, C_PTR, C_INT, LIBXSMM_BLASINT_KIND + TYPE(C_FUNPTR), INTENT(OUT) :: kernel + INTEGER(C_INT), INTENT(IN) :: gemm_precision + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + TYPE(C_PTR), INTENT(IN), VALUE :: lda, ldb, ldc + TYPE(C_PTR), INTENT(IN), VALUE :: alpha, beta + TYPE(C_PTR), INTENT(IN), VALUE :: flags, prefetch + END SUBROUTINE + + !> Type-generic (unsafe) code dispatch (trylock: impure routine). + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: iprec, oprec, flags, prefetch + !> INTEGER(4|8) :: m, n, k, lda, ldb, ldc + !> REAL(4|8) :: alpha, beta + !> INTEGER(8) :: kernel + SUBROUTINE libxsmm_xmmdispatch2(kernel, iprec, oprec, & + & m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch) & + & BIND(C, NAME="libxsmm_xmmdispatch2_") + IMPORT :: C_FUNPTR, C_PTR, C_INT, LIBXSMM_BLASINT_KIND + TYPE(C_FUNPTR), INTENT(OUT) :: kernel + INTEGER(C_INT), INTENT(IN) :: iprec, oprec + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + TYPE(C_PTR), INTENT(IN), VALUE :: lda, ldb, ldc + TYPE(C_PTR), INTENT(IN), VALUE :: alpha, beta + TYPE(C_PTR), INTENT(IN), VALUE :: flags, prefetch + END SUBROUTINE + + !> Generic call routine (3-argument form). + !> Implicit FORTRAN 77 interface: + !> REAL(4|8) :: a, b, c + !> INTEGER(8) :: kernel + PURE SUBROUTINE libxsmm_xmmcall_abc(kernel, a, b, c) & + & BIND(C, NAME="libxsmm_xmmcall_abc_") + IMPORT :: C_FUNPTR, C_PTR + TYPE(C_FUNPTR), INTENT(IN) :: kernel + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c + END SUBROUTINE + + !> Generic call routine (6-argument form). + !> Implicit FORTRAN 77 interface: + !> REAL(4|8) :: a, b, c, pa, pb, pc + !> INTEGER(8) :: kernel + PURE SUBROUTINE libxsmm_xmmcall_prf(kernel, & + & a, b, c, pa, pb, pc) & + & BIND(C, NAME="libxsmm_xmmcall_prf_") + IMPORT :: C_FUNPTR, C_PTR + TYPE(C_FUNPTR), INTENT(IN) :: kernel + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c, pa, pb, pc + END SUBROUTINE + + !> Fill destination with zeros; treats dst in raw/binary fashion. + SUBROUTINE libxsmm_xclear(dst, nbytes) & + & BIND(C, NAME="libxsmm_xclear_") + IMPORT :: C_PTR, C_INT + TYPE(C_PTR), INTENT(IN), VALUE :: dst + INTEGER(C_INT), INTENT(IN) :: nbytes + END SUBROUTINE + + !> Remove key-value pair from code registry and release memory. + SUBROUTINE libxsmm_xrelease(key, keysize) & + & BIND(C, NAME="libxsmm_xrelease_") + IMPORT :: C_PTR, C_INT + TYPE(C_PTR), INTENT(IN), VALUE :: key + INTEGER(C_INT), INTENT(IN) :: keysize + END SUBROUTINE + + !> Matrix-copy (2-dimensional copy) routine. + !> Implicit FORTRAN 77 interface: + !> ARRAY :: input, output + !> INTEGER(4|8) :: m, n, ldi, ldo + !> INTEGER(4) :: typesize + PURE SUBROUTINE libxsmm_xmatcopy(output, input, typesize, & + & m, n, ldi, ldo) BIND(C, NAME="libxsmm_matcopy_") + IMPORT :: LIBXSMM_BLASINT_KIND, C_PTR, C_INT + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + TYPE(C_PTR), INTENT(IN), VALUE :: output, input + INTEGER(C_INT), INTENT(IN) :: typesize + END SUBROUTINE + + !> Transpose a matrix (in-place form). + !> Implicit FORTRAN 77 interface: + !> ARRAY :: matrix + !> INTEGER(4|8) :: m, n, ldi, ldo + !> INTEGER(4) :: typesize + PURE SUBROUTINE libxsmm_xitrans(matrix, typesize, & + & m, n, ldi, ldo) BIND(C, NAME="libxsmm_itrans_") + IMPORT :: C_PTR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + TYPE(C_PTR), INTENT(IN), VALUE :: matrix + INTEGER(C_INT), INTENT(IN) :: typesize + END SUBROUTINE + + !> Transpose a matrix (out-of-place form). + !> Implicit FORTRAN 77 interface: + !> ARRAY :: input, output + !> INTEGER(4|8) :: m, n, ldi, ldo + !> INTEGER(4) :: typesize + PURE SUBROUTINE libxsmm_xotrans(output, input, typesize, & + & m, n, ldi, ldo) BIND(C, NAME="libxsmm_otrans_") + IMPORT :: C_PTR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + TYPE(C_PTR), INTENT(IN), VALUE :: output, input + INTEGER(C_INT), INTENT(IN) :: typesize + END SUBROUTINE + + !> Matrix copy; MT via libxsmmext (out-of-place form). + !> Implicit FORTRAN 77 interface: + !> ARRAY :: output, input + !> INTEGER(4|8) :: m, n, ldi, ldo + !> INTEGER(4) :: typesize + PURE SUBROUTINE libxsmm_matcopy_omp(output, input, typesize, & + & m, n, ldi, ldo) BIND(C, NAME="libxsmm_matcopy_omp_") + IMPORT :: C_PTR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + TYPE(C_PTR), INTENT(IN), VALUE :: output, input + INTEGER(C_INT), INTENT(IN) :: typesize + END SUBROUTINE + + !> Matrix transposition; MT via libxsmmext (out-of-place form). + !> Implicit FORTRAN 77 interface: + !> ARRAY :: output, input + !> INTEGER(4|8) :: m, n, ldi, ldo + !> INTEGER(4) :: typesize + PURE SUBROUTINE libxsmm_otrans_omp(output, input, typesize, & + & m, n, ldi, ldo) BIND(C, NAME="libxsmm_otrans_omp_") + IMPORT :: C_PTR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + TYPE(C_PTR), INTENT(IN), VALUE :: output, input + INTEGER(C_INT), INTENT(IN) :: typesize + END SUBROUTINE + + !> General dense MM; MT via libxsmmext (double-precision). + !> Implicit FORTRAN 77 interface: similar to DGEMM. + PURE SUBROUTINE libxsmm_dgemm_omp(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_dgemm_omp_") + IMPORT :: C_DOUBLE, C_CHAR, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + REAL(C_DOUBLE), INTENT(IN) :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(lda,*), b(ldb,*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(ldc,*) + END SUBROUTINE + + !> General dense MM; MT via libxsmmext (single-precision). + !> Implicit FORTRAN 77 interface: similar to SGEMM. + PURE SUBROUTINE libxsmm_sgemm_omp(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_sgemm_omp_") + IMPORT :: C_FLOAT, C_CHAR, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + REAL(C_FLOAT), INTENT(IN) :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(lda,*), b(ldb,*) + REAL(C_FLOAT), INTENT(INOUT) :: c(ldc,*) + END SUBROUTINE + + !> Process a series of MMs (batch). See also libxsmm_gemm_batch_omp. + !> The kind of matrix operands (a, b, c) depend on index_stride: + !> index_stride==0: pointers to pointers of elements, e.g., + !> double** for the C matrices. + !> index_stride!=0: pointer to elements, e.g., + !> const double* for the A and B matrices. + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: iprec, oprec + !> REAL(4|8) :: alpha, beta + !> ARRAY :: a, b, c + !> ARRAY/VALUE :: stride_a, stride_b, stride_c + !> INTEGER(4|8) :: index_base, index_stride, batchsize + !> INTEGER(4) :: tid, nthreads + !> Otherwise arguments are similar to GEMM. + PURE SUBROUTINE libxsmm_mmbatch(iprec, oprec, transa, transb, & + & m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, index_base, & + & index_stride, stride_a, stride_b, stride_c, batchsize, & + & tid, nthreads) & + & BIND(C, NAME="libxsmm_mmbatch_") + IMPORT :: C_PTR, C_CHAR, C_INT, LIBXSMM_BLASINT_KIND + !> Determines index-base (usually 0, 1 for one-based indexes). + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: index_base + !> Stride (measured in Bytes) used to walk stride_*. + !> In Fortran: index_stride!=0. + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: index_stride + !> Number of SMMs. If the size is given as a negative value, + !> then internal synchronization is omitted. + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: batchsize + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + TYPE(C_PTR), INTENT(IN), VALUE :: alpha, beta + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c + !> Arrays of indexes determining the position of + !> a, b, and c operands. + TYPE(C_PTR), INTENT(IN), VALUE :: stride_a + TYPE(C_PTR), INTENT(IN), VALUE :: stride_b + TYPE(C_PTR), INTENT(IN), VALUE :: stride_c + INTEGER(C_INT), INTENT(IN) :: iprec, oprec + !> Thread-ID (TID), and number of threads. + INTEGER(C_INT), INTENT(IN) :: tid, nthreads + END SUBROUTINE + + !> Process a series of SMMs (batch). See also libxsmm_mmbatch. + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: iprec, oprec + !> REAL(4|8) :: alpha, beta + !> ARRAY :: a, b, c + !> ARRAY/VALUE :: stride_a, stride_b, stride_c + !> INTEGER(4|8) :: index_base, index_stride, batchsize + !> Otherwise arguments are similar to GEMM. + PURE SUBROUTINE libxsmm_gemm_batch(iprec, oprec, & + & transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, & + & index_base, index_stride, stride_a, stride_b, stride_c, & + & batchsize) & + & BIND(C, NAME="libxsmm_gemm_batch_") + IMPORT :: C_PTR, C_CHAR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: index_base + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: index_stride + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: batchsize + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + TYPE(C_PTR), INTENT(IN), VALUE :: alpha, beta + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c + TYPE(C_PTR), INTENT(IN), VALUE :: stride_a + TYPE(C_PTR), INTENT(IN), VALUE :: stride_b + TYPE(C_PTR), INTENT(IN), VALUE :: stride_c + INTEGER(C_INT), INTENT(IN) :: iprec, oprec + END SUBROUTINE + + !> Process a series of SMMs (batch) with OpenMP (libxsmmext). + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: iprec, oprec + !> REAL(4|8) :: alpha, beta + !> ARRAY :: a, b, c + !> ARRAY/VALUE :: stride_a, stride_b, stride_c + !> INTEGER(4|8) :: index_base, index_stride, batchsize + !> Otherwise arguments are similar to GEMM. + PURE SUBROUTINE libxsmm_gemm_batch_omp(iprec, oprec, & + & transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, & + & index_base, index_stride, stride_a, stride_b, stride_c, & + & batchsize) & + & BIND(C, NAME="libxsmm_gemm_batch_omp_") + IMPORT :: C_PTR, C_CHAR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: index_base + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: index_stride + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: batchsize + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + TYPE(C_PTR), INTENT(IN), VALUE :: alpha, beta + TYPE(C_PTR), INTENT(IN), VALUE :: a, b, c + TYPE(C_PTR), INTENT(IN), VALUE :: stride_a + TYPE(C_PTR), INTENT(IN), VALUE :: stride_b + TYPE(C_PTR), INTENT(IN), VALUE :: stride_c + INTEGER(C_INT), INTENT(IN) :: iprec, oprec + END SUBROUTINE + + !> This function is a no-op unless LIBXSMM is built to intercept GEMM. + !> Pointer arguments are used to filter intercepted GEMM calls such that + !> non-NULL values match. Otherwise (NULL) the respective argument is + !> considered a "free value", i.e., every value can match; + !> libxsmmext required. + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: gemm_precision, flags + !> INTEGER(4|8) :: m, n, k, lda, ldb, ldc + !> REAL(4|8) :: alpha, beta + SUBROUTINE libxsmm_mmbatch_begin(gemm_precision, flags, & + & m, n, k, lda, ldb, ldc, alpha, beta) BIND(C) + IMPORT :: C_PTR, C_INT, LIBXSMM_BLASINT_KIND + INTEGER(C_INT), INTENT(IN), VALUE :: gemm_precision + INTEGER(C_INT), INTENT(IN) :: flags + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + TYPE(C_PTR), INTENT(IN), VALUE :: alpha, beta + END SUBROUTINE + + !> Processes the batch of previously recorded SMMs + !> (libxsmm_mmbatch_begin); libxsmmext required. + !> Implicit FORTRAN 77 interface: available. + SUBROUTINE libxsmm_mmbatch_end() BIND(C) + END SUBROUTINE + + !> Reduces input into output such that the difference is maintained + !> or increased (max function). The very first (initial) output + !> should be zeroed (libxsmm_matdiff_clear). + !> Implicit FORTRAN 77 interface: available. + PURE SUBROUTINE libxsmm_matdiff_reduce(output, input) BIND(C) + IMPORT :: LIBXSMM_MATDIFF_INFO + TYPE(LIBXSMM_MATDIFF_INFO), INTENT(INOUT) :: output + TYPE(LIBXSMM_MATDIFF_INFO), INTENT(IN) :: input + END SUBROUTINE + + !> Clears the given info-structure, e.g., for the initial + !> reduction-value (libxsmm_matdiff_reduce). + !> Implicit FORTRAN 77 interface: available. + PURE SUBROUTINE libxsmm_matdiff_clear(info) BIND(C) + IMPORT :: LIBXSMM_MATDIFF_INFO + TYPE(LIBXSMM_MATDIFF_INFO), INTENT(OUT) :: info + END SUBROUTINE + + !> Calculates a hash value for the given array and seed. + !> Routine suitable for FORTRAN 77; keysize in Bytes. + PURE SUBROUTINE libxsmm_xhash(hash_seed, key, keysize) & + & BIND(C, NAME="libxsmm_xhash_") + IMPORT :: C_INT, C_PTR + INTEGER(C_INT), INTENT(INOUT) :: hash_seed + INTEGER(C_INT), INTENT(IN) :: keysize + TYPE(C_PTR), INTENT(IN), VALUE :: key + END SUBROUTINE + + !> Calculates if there is a difference between two arrays. + !> Routine suitable for FORTRAN 77; size in Bytes. + PURE SUBROUTINE libxsmm_xdiff(diff, a, b, nbytes) & + & BIND(C, NAME="libxsmm_xdiff_") + IMPORT :: C_PTR, C_LONG_LONG, C_BOOL + TYPE(C_PTR), INTENT(IN), VALUE :: a, b + INTEGER(C_LONG_LONG), INTENT(IN) :: nbytes + LOGICAL(C_BOOL), INTENT(OUT) :: diff + END SUBROUTINE + END INTERFACE + + INTERFACE libxsmm_ptr0 + MODULE PROCEDURE libxsmm_ptr_z0, libxsmm_ptr_c0 + MODULE PROCEDURE libxsmm_ptr_d0, libxsmm_ptr_s0 + MODULE PROCEDURE libxsmm_ptr_i0, libxsmm_ptr_w0 + MODULE PROCEDURE libxsmm_ptr_j0 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_b0 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_l0 !! long long + END INTERFACE + + INTERFACE libxsmm_ptr1 + MODULE PROCEDURE libxsmm_ptr_z1, libxsmm_ptr_c1 + MODULE PROCEDURE libxsmm_ptr_d1, libxsmm_ptr_s1 + MODULE PROCEDURE libxsmm_ptr_i1, libxsmm_ptr_w1 + MODULE PROCEDURE libxsmm_ptr_j1 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_b1 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_l1 !! long long + MODULE PROCEDURE libxsmm_ptr_dmm + MODULE PROCEDURE libxsmm_ptr_smm + MODULE PROCEDURE libxsmm_ptr_wimm + END INTERFACE + + INTERFACE libxsmm_ptr2 + MODULE PROCEDURE libxsmm_ptr_z2, libxsmm_ptr_c2 + MODULE PROCEDURE libxsmm_ptr_d2, libxsmm_ptr_s2 + MODULE PROCEDURE libxsmm_ptr_i2, libxsmm_ptr_w2 + MODULE PROCEDURE libxsmm_ptr_j2 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_b2 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_l2 !! long long + END INTERFACE + + INTERFACE libxsmm_ptr + MODULE PROCEDURE libxsmm_ptr_z0, libxsmm_ptr_c0 + MODULE PROCEDURE libxsmm_ptr_d0, libxsmm_ptr_s0 + MODULE PROCEDURE libxsmm_ptr_i0, libxsmm_ptr_w0 + MODULE PROCEDURE libxsmm_ptr_j0 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_b0 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_l0 !! long long + MODULE PROCEDURE libxsmm_ptr_z1, libxsmm_ptr_c1 + MODULE PROCEDURE libxsmm_ptr_d1, libxsmm_ptr_s1 + MODULE PROCEDURE libxsmm_ptr_i1, libxsmm_ptr_w1 + MODULE PROCEDURE libxsmm_ptr_j1 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_b1 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_l1 !! long long + MODULE PROCEDURE libxsmm_ptr_z2, libxsmm_ptr_c2 + MODULE PROCEDURE libxsmm_ptr_d2, libxsmm_ptr_s2 + MODULE PROCEDURE libxsmm_ptr_i2, libxsmm_ptr_w2 + MODULE PROCEDURE libxsmm_ptr_j2 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_b2 !! Byte/char + MODULE PROCEDURE libxsmm_ptr_l2 !! long long + MODULE PROCEDURE libxsmm_ptr_dmm + MODULE PROCEDURE libxsmm_ptr_smm + MODULE PROCEDURE libxsmm_ptr_wimm + END INTERFACE + + !> Deallocates JIT'ted code, or unregisters/releases code from registry. + INTERFACE libxsmm_release_mmkernel + MODULE PROCEDURE libxsmm_release_dmmkernel + MODULE PROCEDURE libxsmm_release_smmkernel + MODULE PROCEDURE libxsmm_release_wimmkernel + END INTERFACE + + !> Construct JIT-code depending on given argument set. + INTERFACE libxsmm_mmdispatch + MODULE PROCEDURE libxsmm_dmmdispatch, libxsmm_smmdispatch + MODULE PROCEDURE libxsmm_wimmdispatch + END INTERFACE + + !> Construct JIT-code depending on given argument set. + INTERFACE libxsmm_dispatch + MODULE PROCEDURE libxsmm_dmmdispatch, libxsmm_smmdispatch + MODULE PROCEDURE libxsmm_wimmdispatch + END INTERFACE + + !> Check if a function is available (LIBXSMM_?MMFUNCTION). + INTERFACE libxsmm_mmavailable + MODULE PROCEDURE libxsmm_dmmavailable, libxsmm_smmavailable + MODULE PROCEDURE libxsmm_wimmavailable + END INTERFACE + + !> Check if a function is available (LIBXSMM_?MMFUNCTION). + INTERFACE libxsmm_available + MODULE PROCEDURE libxsmm_smmavailable, libxsmm_dmmavailable + MODULE PROCEDURE libxsmm_wimmavailable + END INTERFACE + + !> Overloaded GEMM routines (double-precision). + INTERFACE libxsmm_dgemm + MODULE PROCEDURE libxsmm_dgemm0 + MODULE PROCEDURE libxsmm_dgemm1 + MODULE PROCEDURE libxsmm_dgemm2 + MODULE PROCEDURE libxsmm_dgemm3 + END INTERFACE + + !> Overloaded GEMM routines (single-precision). + INTERFACE libxsmm_sgemm + MODULE PROCEDURE libxsmm_sgemm0 + MODULE PROCEDURE libxsmm_sgemm1 + MODULE PROCEDURE libxsmm_sgemm2 + END INTERFACE + + !> Overloaded GEMM routines (low-precision). + INTERFACE libxsmm_wigemm + MODULE PROCEDURE libxsmm_wigemm0 + MODULE PROCEDURE libxsmm_wigemm1 + MODULE PROCEDURE libxsmm_wigemm2 + END INTERFACE + + !> Overloaded GEMM routines. + INTERFACE libxsmm_gemm + MODULE PROCEDURE libxsmm_dgemm0 + MODULE PROCEDURE libxsmm_dgemm1 + MODULE PROCEDURE libxsmm_dgemm2 + MODULE PROCEDURE libxsmm_dgemm3 + MODULE PROCEDURE libxsmm_sgemm0 + MODULE PROCEDURE libxsmm_sgemm1 + MODULE PROCEDURE libxsmm_sgemm2 + MODULE PROCEDURE libxsmm_sgemm3 + MODULE PROCEDURE libxsmm_wigemm0 + MODULE PROCEDURE libxsmm_wigemm1 + MODULE PROCEDURE libxsmm_wigemm2 + MODULE PROCEDURE libxsmm_wigemm3 + END INTERFACE + + !> Overloaded BLAS GEMM routines (double-precision). + INTERFACE libxsmm_blas_dgemm + MODULE PROCEDURE libxsmm_blas_dgemm0 + MODULE PROCEDURE libxsmm_blas_dgemm1 + MODULE PROCEDURE libxsmm_blas_dgemm2 + MODULE PROCEDURE libxsmm_blas_dgemm3 + END INTERFACE + + !> Overloaded BLAS GEMM routines (single-precision). + INTERFACE libxsmm_blas_sgemm + MODULE PROCEDURE libxsmm_blas_sgemm0 + MODULE PROCEDURE libxsmm_blas_sgemm1 + MODULE PROCEDURE libxsmm_blas_sgemm2 + MODULE PROCEDURE libxsmm_blas_sgemm3 + END INTERFACE + + !> Overloaded BLAS GEMM routines (single/double-precision). + INTERFACE libxsmm_blas_gemm + MODULE PROCEDURE libxsmm_blas_dgemm0 + MODULE PROCEDURE libxsmm_blas_dgemm1 + MODULE PROCEDURE libxsmm_blas_dgemm2 + MODULE PROCEDURE libxsmm_blas_dgemm3 + MODULE PROCEDURE libxsmm_blas_sgemm0 + MODULE PROCEDURE libxsmm_blas_sgemm1 + MODULE PROCEDURE libxsmm_blas_sgemm2 + MODULE PROCEDURE libxsmm_blas_sgemm3 + END INTERFACE + + !> Overloaded MATCOPY routines (2d-copy). + INTERFACE libxsmm_matcopy + MODULE PROCEDURE libxsmm_matcopy_p0 + MODULE PROCEDURE libxsmm_matcopy_d1 + MODULE PROCEDURE libxsmm_matcopy_d2 + MODULE PROCEDURE libxsmm_matcopy_s1 + MODULE PROCEDURE libxsmm_matcopy_s2 + END INTERFACE + + !> Overloaded TRANSPOSE routines (in-place form). + INTERFACE libxsmm_itrans + MODULE PROCEDURE libxsmm_itrans_p0 + MODULE PROCEDURE libxsmm_itrans_d1 + MODULE PROCEDURE libxsmm_itrans_d2 + MODULE PROCEDURE libxsmm_itrans_s1 + MODULE PROCEDURE libxsmm_itrans_s2 + END INTERFACE + + !> Overloaded TRANSPOSE routines (out-of-place form). + INTERFACE libxsmm_otrans + MODULE PROCEDURE libxsmm_otrans_p0 + MODULE PROCEDURE libxsmm_otrans_d1 + MODULE PROCEDURE libxsmm_otrans_d2 + MODULE PROCEDURE libxsmm_otrans_s1 + MODULE PROCEDURE libxsmm_otrans_s2 + END INTERFACE + + !> Calculate a hash value for a given key value (binary blob). + !> Conceptually pure, but C_LOC may be (incorrectly) impure. + INTERFACE libxsmm_hash + MODULE PROCEDURE libxsmm_hash_char + MODULE PROCEDURE libxsmm_hash_i8 + MODULE PROCEDURE libxsmm_hash_i32 + MODULE PROCEDURE libxsmm_hash_i64 + END INTERFACE + + !> Calculate whether there is a difference between two series of items. + !> Conceptually pure, but C_LOC may be (incorrectly) impure. + INTERFACE libxsmm_diff + MODULE PROCEDURE libxsmm_diff_char + MODULE PROCEDURE libxsmm_diff_i8 + MODULE PROCEDURE libxsmm_diff_i32 + MODULE PROCEDURE libxsmm_diff_i64 + END INTERFACE + + CONTAINS + !> Returns the name of the target architecture as determined by + !> the CPUID flags, as set by the libxsmm_get_target_arch* functions, + !> or as set by the LIBXSMM_TARGET environment variable. + FUNCTION libxsmm_get_target_arch() + !CHARACTER(LEN=:), POINTER :: libxsmm_get_target_arch + CHARACTER, POINTER :: libxsmm_get_target_arch(:) + INTEGER(C_INT) :: length(1) + TYPE(C_PTR) :: arch + INTERFACE + FUNCTION libxsmmf_get_target_arch(length) BIND(C) + IMPORT :: C_INT, C_PTR + INTEGER(C_INT), INTENT(OUT) :: length + TYPE(C_PTR) :: libxsmmf_get_target_arch + END FUNCTION + END INTERFACE + arch = libxsmmf_get_target_arch(length(1)) + CALL C_F_POINTER(arch, libxsmm_get_target_arch, length) + END FUNCTION + + !> Returns C_NULL_PTR. + PURE FUNCTION libxsmm_ptr_null() + TYPE(C_PTR) :: libxsmm_ptr_null + libxsmm_ptr_null = C_NULL_PTR + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_z0(a) + COMPLEX(C_DOUBLE_COMPLEX), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_z0 + libxsmm_ptr_z0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_z1(a) + COMPLEX(C_DOUBLE_COMPLEX), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_z1 + libxsmm_ptr_z1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_z2(a) + COMPLEX(C_DOUBLE_COMPLEX), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_z2 + libxsmm_ptr_z2 = libxsmm_ptr_z1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_c0(a) + COMPLEX(C_FLOAT_COMPLEX), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_c0 + libxsmm_ptr_c0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_c1(a) + COMPLEX(C_FLOAT_COMPLEX), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_c1 + libxsmm_ptr_c1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_c2(a) + COMPLEX(C_FLOAT_COMPLEX), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_c2 + libxsmm_ptr_c2 = libxsmm_ptr_c1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_d0(a) + REAL(C_DOUBLE), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_d0 + libxsmm_ptr_d0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_d1(a) + REAL(C_DOUBLE), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_d1 + libxsmm_ptr_d1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_d2(a) + REAL(C_DOUBLE), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_d2 + libxsmm_ptr_d2 = libxsmm_ptr_d1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_s0(a) + REAL(C_FLOAT), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_s0 + libxsmm_ptr_s0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_s1(a) + REAL(C_FLOAT), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_s1 + libxsmm_ptr_s1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_s2(a) + REAL(C_FLOAT), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_s2 + libxsmm_ptr_s2 = libxsmm_ptr_s1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_i0(a) + INTEGER(C_INT), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_i0 + libxsmm_ptr_i0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_i1(a) + INTEGER(C_INT), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_i1 + libxsmm_ptr_i1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_i2(a) + INTEGER(C_INT), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_i2 + libxsmm_ptr_i2 = libxsmm_ptr_i1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_w0(a) + INTEGER(C_SHORT), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_w0 + libxsmm_ptr_w0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_w1(a) + INTEGER(C_SHORT), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_w1 + libxsmm_ptr_w1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_w2(a) + INTEGER(C_SHORT), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_w2 + libxsmm_ptr_w2 = libxsmm_ptr_w1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_j0(a) + INTEGER(C_INT8_T), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_j0 + libxsmm_ptr_j0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_j1(a) + INTEGER(C_INT8_T), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_j1 + libxsmm_ptr_j1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_j2(a) + INTEGER(C_INT8_T), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_j2 + libxsmm_ptr_j2 = libxsmm_ptr_j1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_b0(a) + CHARACTER(C_CHAR), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_b0 + libxsmm_ptr_b0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_b1(a) + CHARACTER(C_CHAR), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_b1 + libxsmm_ptr_b1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_b2(a) + CHARACTER(C_CHAR), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_b2 + libxsmm_ptr_b2 = libxsmm_ptr_b1(a) + END FUNCTION + + !> Determines the C-address of the given array. + FUNCTION libxsmm_ptr_l0(a) + INTEGER(C_LONG_LONG), INTENT(IN), TARGET :: a + TYPE(C_PTR) :: libxsmm_ptr_l0 + libxsmm_ptr_l0 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_l1(a) + INTEGER(C_LONG_LONG), INTENT(IN), TARGET :: a(*) + TYPE(C_PTR) :: libxsmm_ptr_l1 + libxsmm_ptr_l1 = C_LOC(a) + END FUNCTION + FUNCTION libxsmm_ptr_l2(a) + INTEGER(C_LONG_LONG), INTENT(IN) :: a(:,:) + TYPE(C_PTR) :: libxsmm_ptr_l2 + libxsmm_ptr_l2 = libxsmm_ptr_l1(a) + END FUNCTION + + FUNCTION libxsmm_ptr_dmm(a) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(IN), TARGET :: a(:) + TYPE(LIBXSMM_DMMFUNCTION), POINTER :: p + TYPE(C_PTR) :: libxsmm_ptr_dmm + p => a(LBOUND(a,1)); libxsmm_ptr_dmm = C_LOC(p%handle) + END FUNCTION + FUNCTION libxsmm_ptr_smm(a) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(IN), TARGET :: a(:) + TYPE(LIBXSMM_SMMFUNCTION), POINTER :: p + TYPE(C_PTR) :: libxsmm_ptr_smm + p => a(LBOUND(a,1)); libxsmm_ptr_smm = C_LOC(p%handle) + END FUNCTION + FUNCTION libxsmm_ptr_wimm(a) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(IN), TARGET :: a(:) + TYPE(LIBXSMM_WIMMFUNCTION), POINTER :: p + TYPE(C_PTR) :: libxsmm_ptr_wimm + p => a(LBOUND(a,1)); libxsmm_ptr_wimm = C_LOC(p%handle) + END FUNCTION + + !> Deallocate JIT'ted code created by libxsmm_create routines. To + !> unregister code generated with libxsmm_dispatch is unnecessary. + SUBROUTINE libxsmm_release_dmmkernel(kernel) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(IN) :: kernel + CALL libxsmm_release_kernel(kernel%handle) + END SUBROUTINE + + !> Deallocate JIT'ted code created by libxsmm_create routines. To + !> unregister code generated with libxsmm_dispatch is unnecessary. + SUBROUTINE libxsmm_release_smmkernel(kernel) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(IN) :: kernel + CALL libxsmm_release_kernel(kernel%handle) + END SUBROUTINE + + !> Deallocate JIT'ted code created by libxsmm_create routines. To + !> unregister code generated with libxsmm_dispatch is unnecessary. + SUBROUTINE libxsmm_release_wimmkernel(kernel) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(IN) :: kernel + CALL libxsmm_release_kernel(kernel%handle) + END SUBROUTINE + + !> Query or JIT-generate an SMM-kernel (double-precision). + SUBROUTINE libxsmm_dmmdispatch(kernel, & + & m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(OUT) :: kernel + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), & + & OPTIONAL, TARGET :: lda, ldb, ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL, TARGET :: alpha, beta + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: flags + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: prefetch + CALL libxsmm_xmmdispatch( & + & kernel%handle, LIBXSMM_GEMM_PRECISION_F64, & + & m, n, k, C_LOC(lda), C_LOC(ldb), C_LOC(ldc), & + & C_LOC(alpha), C_LOC(beta), C_LOC(flags), C_LOC(prefetch)) + END SUBROUTINE + + !> Query or JIT-generate an SMM-kernel (single-precision). + SUBROUTINE libxsmm_smmdispatch(kernel, & + & m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(OUT) :: kernel + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), & + & OPTIONAL, TARGET :: lda, ldb, ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL, TARGET :: alpha, beta + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: flags + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: prefetch + CALL libxsmm_xmmdispatch( & + & kernel%handle, LIBXSMM_GEMM_PRECISION_F32, & + & m, n, k, C_LOC(lda), C_LOC(ldb), C_LOC(ldc), & + & C_LOC(alpha), C_LOC(beta), C_LOC(flags), C_LOC(prefetch)) + END SUBROUTINE + + !> Query or JIT-generate an SMM-kernel (low-precision, int-accumulate). + SUBROUTINE libxsmm_wimmdispatch(kernel, & + & m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(OUT) :: kernel + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), & + & OPTIONAL, TARGET :: lda, ldb, ldc + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: alpha, beta + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: flags + INTEGER(C_INT), INTENT(IN), OPTIONAL, TARGET :: prefetch + CALL libxsmm_xmmdispatch2(kernel%handle, & + & LIBXSMM_GEMM_PRECISION_I16, LIBXSMM_GEMM_PRECISION_I32, & + & m, n, k, C_LOC(lda), C_LOC(ldb), C_LOC(ldc), & + & C_LOC(alpha), C_LOC(beta), C_LOC(flags), C_LOC(prefetch)) + END SUBROUTINE + + !> Checks if the given kernel was generated. JIT code is guaranteed + !> to be generated if JIT support was enabled at build-time of the + !> library (default). This overload belongs to libxsmm_(mm)available. + LOGICAL FUNCTION libxsmm_dmmavailable(kernel) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(IN) :: kernel + libxsmm_dmmavailable = C_ASSOCIATED(kernel%handle) + END FUNCTION + + !> Checks if the given kernel was generated. JIT code is guaranteed + !> to be generated if JIT support was enabled at build-time of the + !> library (default). This overload belongs to libxsmm_(mm)available. + LOGICAL FUNCTION libxsmm_smmavailable(kernel) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(IN) :: kernel + libxsmm_smmavailable = C_ASSOCIATED(kernel%handle) + END FUNCTION + + !> Checks if the given kernel was generated. JIT code is guaranteed + !> to be generated if JIT support was enabled at build-time of the + !> library (default). This overload belongs to libxsmm_(mm)available. + LOGICAL FUNCTION libxsmm_wimmavailable(kernel) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(IN) :: kernel + libxsmm_wimmavailable = C_ASSOCIATED(kernel%handle) + END FUNCTION + + !> Calls the kernel with the given arguments. Alternatively, + !> PROCPOINTER can be used as shown by the inner comments + !> of this routine (LIBXSMM_FUNCTION3). The libxsmm_xmmcall + !> routines can be used in FORTRAN77. + SUBROUTINE libxsmm_dmmcall_abc(kernel, a, b, c) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(IN) :: kernel + REAL(C_DOUBLE), INTENT(IN), TARGET :: a(*), b(*) + REAL(C_DOUBLE), INTENT(INOUT), TARGET :: c(*) + ! PROCEDURE(LIBXSMM_FUNCTION3), POINTER :: xmm + ! CALL C_F_PROCPOINTER(kernel%handle, xmm) + ! CALL xmm(...) + CALL libxsmm_xmmcall_abc(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c)) + END SUBROUTINE + + !> Calls the kernel with the given arguments. Alternatively, + !> PROCPOINTER can be used as shown by the inner comments + !> of this routine (LIBXSMM_FUNCTION6). The libxsmm_xmmcall + !> routines can be used in FORTRAN77. + SUBROUTINE libxsmm_dmmcall_prf(kernel, a, b, c, pa, pb, pc) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(IN) :: kernel + REAL(C_DOUBLE), INTENT(IN), TARGET :: a(*), b(*) + REAL(C_DOUBLE), INTENT(INOUT), TARGET :: c(*) + REAL(C_DOUBLE), INTENT(IN), TARGET :: pa(*) + REAL(C_DOUBLE), INTENT(IN), TARGET :: pb(*) + REAL(C_DOUBLE), INTENT(IN), TARGET :: pc(*) + ! PROCEDURE(LIBXSMM_FUNCTION6), POINTER :: xmm + ! CALL C_F_PROCPOINTER(kernel%handle, xmm) + ! CALL xmm(...) + CALL libxsmm_xmmcall_prf(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c), & + & C_LOC(pa), C_LOC(pb), C_LOC(pc)) + END SUBROUTINE + + !> See also libxsmm_dmmcall_abc and libxsmm_dmmcall_prf. + SUBROUTINE libxsmm_dmmcall(kernel, a, b, c, pa, pb, pc) + TYPE(LIBXSMM_DMMFUNCTION), INTENT(IN) :: kernel + REAL(C_DOUBLE), INTENT(IN), TARGET :: a(*), b(*) + REAL(C_DOUBLE), INTENT(INOUT), TARGET :: c(*) + REAL(C_DOUBLE), INTENT(IN), OPTIONAL, TARGET :: pa(*) + REAL(C_DOUBLE), INTENT(IN), OPTIONAL, TARGET :: pb(*) + REAL(C_DOUBLE), INTENT(IN), OPTIONAL, TARGET :: pc(*) + ! use .OR. instead of .AND. to avoid full check + IF (PRESENT(pa).OR.PRESENT(pb).OR.PRESENT(pc)) THEN + CALL libxsmm_xmmcall_prf(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c), & + & C_LOC(pa), C_LOC(pb), C_LOC(pc)) + ELSE + CALL libxsmm_xmmcall_abc(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c)) + END IF + END SUBROUTINE + + !> Calls the kernel with the given arguments. Alternatively, + !> PROCPOINTER can be used as shown by the inner comments + !> of this routine (LIBXSMM_FUNCTION3). The libxsmm_xmmcall + !> routines can be used in FORTRAN77. + SUBROUTINE libxsmm_smmcall_abc(kernel, a, b, c) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(IN) :: kernel + REAL(C_FLOAT), INTENT(IN), TARGET :: a(*), b(*) + REAL(C_FLOAT), INTENT(INOUT), TARGET :: c(*) + ! PROCEDURE(LIBXSMM_FUNCTION3), POINTER :: xmm + ! CALL C_F_PROCPOINTER(kernel%handle, xmm) + ! CALL xmm(...) + CALL libxsmm_xmmcall_abc(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c)) + END SUBROUTINE + + !> Calls the kernel with the given arguments. Alternatively, + !> PROCPOINTER can be used as shown by the inner comments + !> of this routine (LIBXSMM_FUNCTION6). The libxsmm_xmmcall + !> routines can be used in FORTRAN77. + SUBROUTINE libxsmm_smmcall_prf(kernel, a, b, c, pa, pb, pc) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(IN) :: kernel + REAL(C_FLOAT), INTENT(IN), TARGET :: a(*), b(*) + REAL(C_FLOAT), INTENT(INOUT), TARGET :: c(*) + REAL(C_FLOAT), INTENT(IN), TARGET :: pa(*) + REAL(C_FLOAT), INTENT(IN), TARGET :: pb(*) + REAL(C_FLOAT), INTENT(IN), TARGET :: pc(*) + ! PROCEDURE(LIBXSMM_FUNCTION6), POINTER :: xmm + ! CALL C_F_PROCPOINTER(kernel%handle, xmm) + ! CALL xmm(...) + CALL libxsmm_xmmcall_prf(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c), & + & C_LOC(pa), C_LOC(pb), C_LOC(pc)) + END SUBROUTINE + + !> See also libxsmm_smmcall_abc and libxsmm_smmcall_prf. + SUBROUTINE libxsmm_smmcall(kernel, a, b, c, pa, pb, pc) + TYPE(LIBXSMM_SMMFUNCTION), INTENT(IN) :: kernel + REAL(C_FLOAT), INTENT(IN), TARGET :: a(*), b(*) + REAL(C_FLOAT), INTENT(INOUT), TARGET :: c(*) + REAL(C_FLOAT), INTENT(IN), OPTIONAL, TARGET :: pa(*) + REAL(C_FLOAT), INTENT(IN), OPTIONAL, TARGET :: pb(*) + REAL(C_FLOAT), INTENT(IN), OPTIONAL, TARGET :: pc(*) + ! use .OR. instead of .AND. to avoid full check + IF (PRESENT(pa).OR.PRESENT(pb).OR.PRESENT(pc)) THEN + CALL libxsmm_xmmcall_prf(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c), & + & C_LOC(pa), C_LOC(pb), C_LOC(pc)) + ELSE + CALL libxsmm_xmmcall_abc(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c)) + END IF + END SUBROUTINE + + !> Calls the kernel with the given arguments. Alternatively, + !> PROCPOINTER can be used as shown by the inner comments + !> of this routine (LIBXSMM_FUNCTION3). The libxsmm_xmmcall + !> routines can be used in FORTRAN77. + SUBROUTINE libxsmm_wimmcall_abc(kernel, a, b, c) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(IN) :: kernel + INTEGER(C_SHORT), INTENT(IN), TARGET :: a(*), b(*) + INTEGER(C_INT), INTENT(INOUT), TARGET :: c(*) + ! PROCEDURE(LIBXSMM_FUNCTION3), POINTER :: xmm + ! CALL C_F_PROCPOINTER(kernel%handle, xmm) + ! CALL xmm(...) + CALL libxsmm_xmmcall_abc(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c)) + END SUBROUTINE + + !> Calls the kernel with the given arguments. Alternatively, + !> PROCPOINTER can be used as shown by the inner comments + !> of this routine (LIBXSMM_FUNCTION6). The libxsmm_xmmcall + !> routines can be used in FORTRAN77. + SUBROUTINE libxsmm_wimmcall_prf(kernel, a, b, c, pa, pb, pc) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(IN) :: kernel + INTEGER(C_SHORT), INTENT(IN), TARGET :: a(*), b(*) + INTEGER(C_INT), INTENT(INOUT), TARGET :: c(*) + INTEGER(C_SHORT), INTENT(IN), TARGET :: pa(*) + INTEGER(C_SHORT), INTENT(IN), TARGET :: pb(*) + INTEGER(C_SHORT), INTENT(IN), TARGET :: pc(*) + ! PROCEDURE(LIBXSMM_FUNCTION6), POINTER :: xmm + ! CALL C_F_PROCPOINTER(kernel%handle, xmm) + ! CALL xmm(...) + CALL libxsmm_xmmcall_prf(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c), & + & C_LOC(pa), C_LOC(pb), C_LOC(pc)) + END SUBROUTINE + + !> See also libxsmm_wimmcall_abc and libxsmm_wimmcall_prf. + SUBROUTINE libxsmm_wimmcall(kernel, a, b, c, pa, pb, pc) + TYPE(LIBXSMM_WIMMFUNCTION), INTENT(IN) :: kernel + INTEGER(C_SHORT), INTENT(IN), TARGET :: a(*), b(*) + INTEGER(C_INT), INTENT(INOUT), TARGET :: c(*) + INTEGER(C_SHORT), INTENT(IN), OPTIONAL, TARGET :: pa(*) + INTEGER(C_SHORT), INTENT(IN), OPTIONAL, TARGET :: pb(*) + INTEGER(C_SHORT), INTENT(IN), OPTIONAL, TARGET :: pc(*) + ! use .OR. instead of .AND. to avoid full check + IF (PRESENT(pa).OR.PRESENT(pb).OR.PRESENT(pc)) THEN + CALL libxsmm_xmmcall_prf(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c), & + & C_LOC(pa), C_LOC(pb), C_LOC(pc)) + ELSE + CALL libxsmm_xmmcall_abc(kernel%handle, & + & C_LOC(a), C_LOC(b), C_LOC(c)) + END IF + END SUBROUTINE + + !> Register user-defined key-value; value can be queried (libxsmm_xdispatch). + !> Since the key-type is unknown to LIBXSMM, the key must be binary reproducible, + !> i.e., if it is a structured type (padded data may be uninitialized), it must + !> be initially zero-filled (libxsmm_xclear) followed by an element-wise setup. + !> The size of the key is limited (see documentation). The given value is copied + !> by LIBXSMM and may be initialized at registration-time or whenever queried. + !> Registered data is released at program termination but can be also released + !> if needed (libxsmm_xrelease), .e.g., for larger value for the same key. + FUNCTION libxsmm_xregister(key, keysize, valsize, & + & valinit, keyhash) + TYPE(C_PTR), INTENT(IN), VALUE :: key + INTEGER(C_INT), INTENT(IN) :: keysize, valsize + TYPE(C_PTR), INTENT(IN), OPTIONAL :: valinit + INTEGER(C_INT), INTENT(OUT), OPTIONAL :: keyhash + TYPE(C_PTR) :: libxsmm_xregister + INTERFACE + SUBROUTINE internal_xregister(regval, & + & key, keysize, valsize, valinit, keyhash) & + & BIND(C, NAME="libxsmm_xregister_") + IMPORT :: C_PTR, C_INT + TYPE(C_PTR), INTENT(OUT) :: regval + TYPE(C_PTR), INTENT(IN), VALUE :: key, valinit + INTEGER(C_INT), INTENT(IN) :: keysize, valsize + INTEGER(C_INT), INTENT(OUT) :: keyhash + END SUBROUTINE + END INTERFACE + CALL internal_xregister(libxsmm_xregister, & + & key, keysize, valsize, valinit, keyhash) + END FUNCTION + + !> Query user-defined value from LIBXSMM's code registry. + FUNCTION libxsmm_xdispatch(key, keysize, keyhash) + TYPE(C_PTR), INTENT(IN), VALUE :: key + INTEGER(C_INT), INTENT(IN) :: keysize + INTEGER(C_INT), INTENT(OUT), OPTIONAL :: keyhash + TYPE(C_PTR) :: libxsmm_xdispatch + INTERFACE + SUBROUTINE internal_xdispatch(regval, key, keysize, keyhash)& + & BIND(C, NAME="libxsmm_xdispatch_") + IMPORT :: C_PTR, C_INT + TYPE(C_PTR), INTENT(OUT) :: regval + TYPE(C_PTR), INTENT(IN), VALUE :: key + INTEGER(C_INT), INTENT(IN) :: keysize + INTEGER(C_INT), INTENT(OUT) :: keyhash + END SUBROUTINE + END INTERFACE + CALL internal_xdispatch(libxsmm_xdispatch, & + & key, keysize, keyhash) + END FUNCTION + + !> Auto-dispatched general dense MM (double-precision). + !> This overload belongs to libxsmm_(d)gemm. + PURE SUBROUTINE libxsmm_dgemm0(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a, b + REAL(C_DOUBLE), INTENT(INOUT) :: c + INTERFACE + PURE SUBROUTINE internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_dgemm_") + IMPORT :: C_CHAR, C_DOUBLE, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldc + REAL(C_DOUBLE), INTENT(IN) :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a, b + REAL(C_DOUBLE), INTENT(INOUT) :: c + END SUBROUTINE + END INTERFACE + CALL internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + END SUBROUTINE + + !> Auto-dispatched general dense MM (double-precision). + !> This overload belongs to libxsmm_(d)gemm. + PURE SUBROUTINE libxsmm_dgemm1(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(*), b(*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_dgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1)), lda, & + & b(LBOUND(b,1)), ldb, & + & beta, c(LBOUND(c,1)), ldc) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (double-precision). + !> This overload belongs to libxsmm_(d)gemm. + PURE SUBROUTINE libxsmm_dgemm2(transa, transb, m, n, k, & + & a, b, c, alpha, beta) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(m,*), b(k,*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(m,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_dgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), m, & + & b(LBOUND(b,1),LBOUND(b,2)), k, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), m) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (double-precision). + !> This overload belongs to libxsmm_(d)gemm. + PURE SUBROUTINE libxsmm_dgemm3(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(lda,*), b(ldb,*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(ldc,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_dgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), lda, & + & b(LBOUND(b,1),LBOUND(b,2)), ldb, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), ldc) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (single-precision). + !> This overload belongs to libxsmm_(s)gemm. + PURE SUBROUTINE libxsmm_sgemm0(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a, b + REAL(C_FLOAT), INTENT(INOUT) :: c + INTERFACE + PURE SUBROUTINE internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_sgemm_") + IMPORT :: C_CHAR, C_FLOAT, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldc + REAL(C_FLOAT), INTENT(IN) :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a, b + REAL(C_FLOAT), INTENT(INOUT) :: c + END SUBROUTINE + END INTERFACE + CALL internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + END SUBROUTINE + + !> Auto-dispatched general dense MM (single-precision). + !> This overload belongs to libxsmm_(s)gemm. + PURE SUBROUTINE libxsmm_sgemm1(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(*), b(*) + REAL(C_FLOAT), INTENT(INOUT) :: c(*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_sgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1)), lda, & + & b(LBOUND(b,1)), ldb, & + & beta, c(LBOUND(c,1)), ldc) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (single-precision). + !> This overload belongs to libxsmm_(s)gemm. + PURE SUBROUTINE libxsmm_sgemm2(transa, transb, m, n, k, & + & a, b, c, alpha, beta) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(m,*), b(k,*) + REAL(C_FLOAT), INTENT(INOUT) :: c(m,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_sgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), m, & + & b(LBOUND(b,1),LBOUND(b,2)), k, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), m) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (single-precision). + !> This overload belongs to libxsmm_(s)gemm. + PURE SUBROUTINE libxsmm_sgemm3(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(lda,*), b(ldb,*) + REAL(C_FLOAT), INTENT(INOUT) :: c(ldc,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_sgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), lda, & + & b(LBOUND(b,1),LBOUND(b,2)), ldb, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), ldc) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (low-precision, int-accumulate). + !> This overload belongs to libxsmm_(wi)gemm. + PURE SUBROUTINE libxsmm_wigemm0(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + INTEGER(C_INT), INTENT(IN), OPTIONAL :: alpha, beta + INTEGER(C_SHORT), INTENT(IN) :: a, b + INTEGER(C_INT), INTENT(INOUT) :: c + INTERFACE + PURE SUBROUTINE internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_wigemm_") + IMPORT :: C_CHAR, C_SHORT, C_INT, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldc + INTEGER(C_INT), INTENT(IN) :: alpha, beta + INTEGER(C_SHORT), INTENT(IN) :: a, b + INTEGER(C_INT), INTENT(INOUT) :: c + END SUBROUTINE + END INTERFACE + CALL internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + END SUBROUTINE + + !> Auto-dispatched general dense MM (low-precision, int-accumulate). + !> This overload belongs to libxsmm_(wi)gemm. + PURE SUBROUTINE libxsmm_wigemm1(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + INTEGER(C_INT), INTENT(IN), OPTIONAL :: alpha, beta + INTEGER(C_SHORT), INTENT(IN) :: a(*), b(*) + INTEGER(C_INT), INTENT(INOUT) :: c(*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_wigemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1)), lda, & + & b(LBOUND(b,1)), ldb, & + & beta, c(LBOUND(c,1)), ldc) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (low-precision, int-accumulate). + !> This overload belongs to libxsmm_(wi)gemm. + PURE SUBROUTINE libxsmm_wigemm2(transa, transb, m, n, k, & + & a, b, c, alpha, beta) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(C_INT), INTENT(IN), OPTIONAL :: alpha, beta + INTEGER(C_SHORT), INTENT(IN) :: a(m,*), b(k,*) + INTEGER(C_INT), INTENT(INOUT) :: c(m,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_wigemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), m, & + & b(LBOUND(b,1),LBOUND(b,2)), k, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), m) + END IF + END SUBROUTINE + + !> Auto-dispatched general dense MM (low-precision, int-accumulate). + !> This overload belongs to libxsmm_(wi)gemm. + PURE SUBROUTINE libxsmm_wigemm3(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + INTEGER(C_INT), INTENT(IN), OPTIONAL :: alpha, beta + INTEGER(C_SHORT), INTENT(IN) :: a(lda,*), b(ldb,*) + INTEGER(C_INT), INTENT(INOUT) :: c(ldc,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_wigemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), lda, & + & b(LBOUND(b,1),LBOUND(b,2)), ldb, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), ldc) + END IF + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(d)gemm. This overload belongs to libxsmm_blas_(d)gemm. + PURE SUBROUTINE libxsmm_blas_dgemm0(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a, b + REAL(C_DOUBLE), INTENT(INOUT) :: c + INTERFACE + PURE SUBROUTINE internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_blas_dgemm_") + IMPORT :: C_CHAR, C_DOUBLE, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldc + REAL(C_DOUBLE), INTENT(IN) :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a, b + REAL(C_DOUBLE), INTENT(INOUT) :: c + END SUBROUTINE + END INTERFACE + CALL internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(d)gemm. This overload belongs to libxsmm_blas_(d)gemm. + PURE SUBROUTINE libxsmm_blas_dgemm1(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(*), b(*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_blas_dgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1)), lda, & + & b(LBOUND(b,1)), ldb, & + & beta, c(LBOUND(c,1)), ldc) + END IF + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(d)gemm. This overload belongs to libxsmm_blas_(d)gemm. + PURE SUBROUTINE libxsmm_blas_dgemm2(transa, transb, m, n, k, & + & a, b, c, alpha, beta) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(m,*), b(k,*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(m,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_blas_dgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), m, & + & b(LBOUND(b,1),LBOUND(b,2)), k, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), m) + END IF + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(d)gemm. This overload belongs to libxsmm_blas_(d)gemm. + PURE SUBROUTINE libxsmm_blas_dgemm3(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + REAL(C_DOUBLE), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_DOUBLE), INTENT(IN) :: a(lda,*), b(ldb,*) + REAL(C_DOUBLE), INTENT(INOUT) :: c(ldc,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_blas_dgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), lda, & + & b(LBOUND(b,1),LBOUND(b,2)), ldb, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), ldc) + END IF + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(s)gemm. This overload belongs to libxsmm_blas_(s)gemm. + PURE SUBROUTINE libxsmm_blas_sgemm0(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a, b + REAL(C_FLOAT), INTENT(INOUT) :: c + INTERFACE + PURE SUBROUTINE internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) & + & BIND(C, NAME="libxsmm_blas_sgemm_") + IMPORT :: C_CHAR, C_FLOAT, LIBXSMM_BLASINT_KIND + CHARACTER(C_CHAR), INTENT(IN) :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldc + REAL(C_FLOAT), INTENT(IN) :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a, b + REAL(C_FLOAT), INTENT(INOUT) :: c + END SUBROUTINE + END INTERFACE + CALL internal_gemm(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(s)gemm. This overload belongs to libxsmm_blas_(s)gemm. + PURE SUBROUTINE libxsmm_blas_sgemm1(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: lda + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(*), b(*) + REAL(C_FLOAT), INTENT(INOUT) :: c(*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_blas_sgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1)), lda, & + & b(LBOUND(b,1)), ldb, & + & beta, c(LBOUND(c,1)), ldc) + END IF + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(s)gemm. This overload belongs to libxsmm_blas_(s)gemm. + PURE SUBROUTINE libxsmm_blas_sgemm2(transa, transb, m, n, k, & + & a, b, c, alpha, beta) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(m,*), b(k,*) + REAL(C_FLOAT), INTENT(INOUT) :: c(m,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_blas_sgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), m, & + & b(LBOUND(b,1),LBOUND(b,2)), k, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), m) + END IF + END SUBROUTINE + + !> Re-exposes BLAS based GEMM routine with an interfaces similar to + !> libxsmm_(s)gemm. This overload belongs to libxsmm_blas_(s)gemm. + PURE SUBROUTINE libxsmm_blas_sgemm3(transa, transb, m, n, k, & + & alpha, a, lda, b, ldb, beta, c, ldc) + CHARACTER, INTENT(IN), OPTIONAL :: transa, transb + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, k + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: lda, ldb, ldc + REAL(C_FLOAT), INTENT(IN), OPTIONAL :: alpha, beta + REAL(C_FLOAT), INTENT(IN) :: a(lda,*), b(ldb,*) + REAL(C_FLOAT), INTENT(INOUT) :: c(ldc,*) + IF ((0.LT.m).AND.(0.LT.n).AND.(0.LT.k)) THEN + CALL libxsmm_blas_sgemm0(transa, transb, m, n, k, & + & alpha, a(LBOUND(a,1),LBOUND(a,2)), lda, & + & b(LBOUND(b,1),LBOUND(b,2)), ldb, & + & beta, c(LBOUND(c,1),LBOUND(c,2)), ldc) + END IF + END SUBROUTINE + + !> Matrix-copy (2-dimensional copy) routine. If the input (optional) + !> is not present, the routine is used to zero-fill the out-matrix. + PURE SUBROUTINE libxsmm_matcopy_p0(output, input, typesize, & + & m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), & + & OPTIONAL :: n, ldi, ldo + INTEGER(C_INT), INTENT(IN) :: typesize + TYPE(C_PTR), INTENT(IN), OPTIONAL :: input + TYPE(C_PTR), INTENT(IN) :: output + CALL libxsmm_xmatcopy(output, input, typesize, & + & m, n, ldi, ldo) + END SUBROUTINE + + !> Matrix-copy (2-dimensional copy) routine (DP/rank-1). + SUBROUTINE libxsmm_matcopy_d1(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + REAL(C_DOUBLE), INTENT(OUT), TARGET :: output(*) + REAL(C_DOUBLE), INTENT(IN), OPTIONAL, TARGET :: input(*) + CALL libxsmm_xmatcopy(C_LOC(output), C_LOC(input), 8, & + & m, n, ldi, ldo) + END SUBROUTINE + + !> Matrix-copy (2-dimensional copy) routine (DP/rank-2). + SUBROUTINE libxsmm_matcopy_d2(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + REAL(C_DOUBLE), INTENT(OUT), TARGET :: output(ldo,*) + REAL(C_DOUBLE), INTENT(IN), OPTIONAL, TARGET :: input(ldi,*) + CALL libxsmm_xmatcopy(C_LOC(output), C_LOC(input), 8, & + & m, n, ldi, ldo) + END SUBROUTINE + + !> Matrix-copy (2-dimensional copy) routine (SP/rank-1). + SUBROUTINE libxsmm_matcopy_s1(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + REAL(C_FLOAT), INTENT(OUT), TARGET :: output(*) + REAL(C_FLOAT), INTENT(IN), OPTIONAL, TARGET :: input(*) + CALL libxsmm_xmatcopy(C_LOC(output), C_LOC(input), 4, & + & m, n, ldi, ldo) + END SUBROUTINE + + !> Matrix-copy (2-dimensional copy) routine (SP/rank-2). + SUBROUTINE libxsmm_matcopy_s2(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + REAL(C_FLOAT), INTENT(OUT), TARGET :: output(ldo,*) + REAL(C_FLOAT), INTENT(IN), OPTIONAL, TARGET :: input(ldi,*) + CALL libxsmm_xmatcopy(C_LOC(output), C_LOC(input), 4, & + & m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (in-place form). + PURE SUBROUTINE libxsmm_itrans_p0(matrix, typesize, & + & m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + TYPE(C_PTR), INTENT(IN) :: matrix + INTEGER(C_INT), INTENT(IN) :: typesize + CALL libxsmm_xitrans(matrix, typesize, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (in-place form, DP/rank-1). + SUBROUTINE libxsmm_itrans_d1(matrix, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + REAL(C_DOUBLE), INTENT(INOUT), TARGET :: matrix(*) + CALL libxsmm_xitrans(C_LOC(matrix), 8, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (in-place form, DP/rank-2). + SUBROUTINE libxsmm_itrans_d2(matrix, m, n, ld) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ld + REAL(C_DOUBLE), INTENT(INOUT), TARGET :: matrix(ld,*) + CALL libxsmm_xitrans(C_LOC(matrix), 8, m, n, ld, ld) + END SUBROUTINE + + !> Transpose a matrix (in-place form, SP/rank-1). + SUBROUTINE libxsmm_itrans_s1(matrix, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + REAL(C_FLOAT), INTENT(INOUT), TARGET :: matrix(*) + CALL libxsmm_xitrans(C_LOC(matrix), 4, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (in-place form, SP/rank-2). + SUBROUTINE libxsmm_itrans_s2(matrix, m, n, ld) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ld + REAL(C_FLOAT), INTENT(INOUT), TARGET :: matrix(ld,*) + CALL libxsmm_xitrans(C_LOC(matrix), 4, m, n, ld, ld) + END SUBROUTINE + + !> Transpose a matrix (out-of-place form). + PURE SUBROUTINE libxsmm_otrans_p0(output, input, typesize, & + & m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + TYPE(C_PTR), INTENT(IN) :: output, input + INTEGER(C_INT), INTENT(IN) :: typesize + CALL libxsmm_xotrans(output, input, typesize, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (out-of-place form, DP/rank-1). + SUBROUTINE libxsmm_otrans_d1(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + REAL(C_DOUBLE), INTENT(OUT), TARGET :: output(*) + REAL(C_DOUBLE), INTENT(IN), TARGET :: input(*) + CALL libxsmm_xotrans(C_LOC(output), C_LOC(input), & + & 8, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (out-of-place form, DP/rank-2). + SUBROUTINE libxsmm_otrans_d2(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + REAL(C_DOUBLE), INTENT(OUT), TARGET :: output(ldo,*) + REAL(C_DOUBLE), INTENT(IN), TARGET :: input(ldi,*) + CALL libxsmm_xotrans(C_LOC(output), C_LOC(input), & + & 8, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (out-of-place form, SP/rank-1). + SUBROUTINE libxsmm_otrans_s1(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldi + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), OPTIONAL :: ldo + REAL(C_FLOAT), INTENT(OUT), TARGET :: output(*) + REAL(C_FLOAT), INTENT(IN), TARGET :: input(*) + CALL libxsmm_xotrans(C_LOC(output), C_LOC(input), & + & 4, m, n, ldi, ldo) + END SUBROUTINE + + !> Transpose a matrix (out-of-place form, SP/rank-2). + SUBROUTINE libxsmm_otrans_s2(output, input, m, n, ldi, ldo) + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n, ldi, ldo + REAL(C_FLOAT), INTENT(OUT), TARGET :: output(ldo,*) + REAL(C_FLOAT), INTENT(IN), TARGET :: input(ldi,*) + CALL libxsmm_xotrans(C_LOC(output), C_LOC(input), & + & 4, m, n, ldi, ldo) + END SUBROUTINE + + !> Returns the difference between two timer ticks (cycles). + !> Implicit FORTRAN 77 interface: subroutine available. + PURE FUNCTION libxsmm_timer_ncycles(tick0, tick1) + INTEGER(LIBXSMM_TICKINT_KIND), INTENT(IN) :: tick0, tick1 + INTEGER(LIBXSMM_TICKINT_KIND) :: libxsmm_timer_ncycles + INTERFACE + PURE SUBROUTINE internal_timer_ncycles(ncycles, & + & tick0, tick1) BIND(C, NAME="libxsmm_timer_ncycles_") + IMPORT :: LIBXSMM_TICKINT_KIND + INTEGER(LIBXSMM_TICKINT_KIND), INTENT(IN) :: tick0, tick1 + INTEGER(LIBXSMM_TICKINT_KIND), INTENT(OUT) :: ncycles + END SUBROUTINE + END INTERFACE + CALL internal_timer_ncycles( & + & libxsmm_timer_ncycles, tick0, tick1) + END FUNCTION + + !> Utility function to calculate a collection of scalar differences + !> between two matrices (libxsmm_matdiff_info). The location (m, n) + !> of the largest difference (linf_abs) is recorded (also if NaN). + !> In case of NaN, differences are set to infinity. If no difference + !> is discovered, the location (m, n) is negative (OOB). + !> Implicit FORTRAN 77 interface: + !> TYPE :: info + !> INTEGER(4) :: datatype + !> INTEGER(4|8) :: m, n, ldref, ldtst + !> ARRAY :: ref, tst + PURE SUBROUTINE libxsmm_matdiff(info, datatype, m, n, & + & ref, tst, ldref, ldtst) + INTEGER(C_INT), INTENT(IN) :: datatype + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN), & + & OPTIONAL :: n, ldref, ldtst + TYPE(C_PTR), INTENT(IN), OPTIONAL :: ref, tst + TYPE(LIBXSMM_MATDIFF_INFO), INTENT(OUT) :: info + INTERFACE + PURE SUBROUTINE internal_matdiff(info, datatype, m, n, & + & ref, tst, ldref, ldtst) BIND(C, NAME="libxsmm_matdiff_") + IMPORT :: LIBXSMM_MATDIFF_INFO, LIBXSMM_BLASINT_KIND + IMPORT :: C_PTR, C_INT + INTEGER(C_INT), INTENT(IN) :: datatype + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: m, n + INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: ldref, ldtst + TYPE(C_PTR), INTENT(IN), VALUE :: ref, tst + TYPE(LIBXSMM_MATDIFF_INFO), INTENT(OUT) :: info + END SUBROUTINE + END INTERFACE + CALL internal_matdiff(info, datatype, m, n, & + & ref, tst, ldref, ldtst) + END SUBROUTINE + + !> Calculate co-prime number <= n/2 (except: libxsmm_shuffle(0|1) == 0). + !> Implicit FORTRAN 77 interface: + !> INTEGER(4) :: coprime (OUT) + !> INTEGER(4) :: n + ELEMENTAL FUNCTION libxsmm_shuffle(n) + INTEGER(C_LONG_LONG) :: libxsmm_shuffle + INTEGER(C_INT), INTENT(IN) :: n + INTERFACE + PURE SUBROUTINE internal_shuffle(coprime, n) & + & BIND(C, NAME="libxsmm_shuffle_") + IMPORT :: C_LONG_LONG, C_INT + INTEGER(C_LONG_LONG), INTENT(OUT) :: coprime + INTEGER(C_INT), INTENT(IN) :: n + END SUBROUTINE + END INTERFACE + libxsmm_shuffle = INT(0, KIND=C_LONG_LONG) ! avoid warning (older CRAY) + CALL internal_shuffle(libxsmm_shuffle, n) + END FUNCTION + + !> Calculates a hash value for the given array and seed. + !> FORTRAN 77: see libxsmm_xhash + FUNCTION libxsmm_hash_char(key, seed) + CHARACTER(C_CHAR), INTENT(IN) :: key(:) + INTEGER(C_INT), INTENT(IN) :: seed + INTEGER(C_INT) :: libxsmm_hash_char + libxsmm_hash_char = seed + CALL libxsmm_xhash(libxsmm_hash_char, & + & libxsmm_ptr(key), SIZE(key)) + END FUNCTION + + !> Calculates a hash value for the given array and seed. + !> FORTRAN 77: see libxsmm_xhash + FUNCTION libxsmm_hash_i8(key, seed) + INTEGER(C_INT8_T), INTENT(IN) :: key(:) + INTEGER(C_INT), INTENT(IN) :: seed + INTEGER(C_INT) :: libxsmm_hash_i8 + libxsmm_hash_i8 = seed + CALL libxsmm_xhash(libxsmm_hash_i8, & + & libxsmm_ptr(key), SIZE(key)) + END FUNCTION + + !> Calculates a hash value for the given array and seed. + !> FORTRAN 77: see libxsmm_xhash + FUNCTION libxsmm_hash_i32(key, seed) + INTEGER(C_INT), INTENT(IN) :: key(:) + INTEGER(C_INT), INTENT(IN) :: seed + INTEGER(C_INT) :: libxsmm_hash_i32 + libxsmm_hash_i32 = seed + CALL libxsmm_xhash(libxsmm_hash_i32, & + & libxsmm_ptr(key), SIZE(key) * 4) + END FUNCTION + + !> Calculates a hash value for the given array and seed. + !> FORTRAN 77: see libxsmm_xhash + FUNCTION libxsmm_hash_i64(key, seed) + INTEGER(C_LONG_LONG), INTENT(IN) :: key(:) + INTEGER(C_INT), INTENT(IN) :: seed + INTEGER(C_INT) :: libxsmm_hash_i64 + libxsmm_hash_i64 = seed + CALL libxsmm_xhash(libxsmm_hash_i64, & + & libxsmm_ptr(key), SIZE(key) * 8) + END FUNCTION + + !> Calculates if there is a difference between two arrays. + !> FORTRAN 77: see libxsmm_xdiff + FUNCTION libxsmm_diff_char(a, b) + CHARACTER(C_CHAR), INTENT(IN) :: a(:), b(:) + LOGICAL(C_BOOL) :: libxsmm_diff_char + IF (SIZE(a, KIND=C_LONG_LONG) .EQ. SIZE(b, KIND=C_LONG_LONG)) & + & THEN + CALL libxsmm_xdiff(libxsmm_diff_char, & + & libxsmm_ptr(a), libxsmm_ptr(b), & + & SIZE(a, KIND=C_LONG_LONG)) + ELSE + libxsmm_diff_char = LOGICAL(.TRUE., KIND=C_BOOL) + END IF + END FUNCTION + + !> Calculates if there is a difference between two arrays. + !> FORTRAN 77: see libxsmm_xdiff + FUNCTION libxsmm_diff_i8(a, b) + INTEGER(C_INT8_T), INTENT(IN) :: a(:), b(:) + LOGICAL(C_BOOL) :: libxsmm_diff_i8 + IF (SIZE(a, KIND=C_LONG_LONG) .EQ. SIZE(b, KIND=C_LONG_LONG)) & + & THEN + CALL libxsmm_xdiff(libxsmm_diff_i8, & + & libxsmm_ptr(a), libxsmm_ptr(b), & + & SIZE(a, KIND=C_LONG_LONG)) + ELSE + libxsmm_diff_i8 = LOGICAL(.TRUE., KIND=C_BOOL) + END IF + END FUNCTION + + !> Calculates if there is a difference between two arrays. + !> FORTRAN 77: see libxsmm_xdiff + FUNCTION libxsmm_diff_i32(a, b) + INTEGER(C_INT), INTENT(IN) :: a(:), b(:) + LOGICAL(C_BOOL) :: libxsmm_diff_i32 + IF (SIZE(a, KIND=C_LONG_LONG) .EQ. SIZE(b, KIND=C_LONG_LONG)) & + & THEN + CALL libxsmm_xdiff(libxsmm_diff_i32, & + & libxsmm_ptr(a), libxsmm_ptr(b), & + & SIZE(a, KIND=C_LONG_LONG) * INT(4, KIND=C_LONG_LONG)) + ELSE + libxsmm_diff_i32 = LOGICAL(.TRUE., KIND=C_BOOL) + END IF + END FUNCTION + + !> Calculates if there is a difference between two arrays. + !> FORTRAN 77: see libxsmm_xdiff + FUNCTION libxsmm_diff_i64(a, b) + INTEGER(C_LONG_LONG), INTENT(IN) :: a(:), b(:) + LOGICAL(C_BOOL) :: libxsmm_diff_i64 + IF (SIZE(a, KIND=C_LONG_LONG) .EQ. SIZE(b, KIND=C_LONG_LONG)) & + & THEN + CALL libxsmm_xdiff(libxsmm_diff_i64, & + & libxsmm_ptr(a), libxsmm_ptr(b), & + & SIZE(a, KIND=C_LONG_LONG) * INT(8, KIND=C_LONG_LONG)) + ELSE + libxsmm_diff_i64 = LOGICAL(.TRUE., KIND=C_BOOL) + END IF + END FUNCTION + + !> Check if location is SIMD-aligned and optionally consider the next + !> access as if reached by incrementing the location (in Bytes). + !> Optionally calculates the alignment of the given location in Bytes. + FUNCTION libxsmm_aligned(location, increment, alignment) + TYPE(C_PTR), INTENT(IN), VALUE :: location + INTEGER(C_INT), INTENT(IN), OPTIONAL :: increment + INTEGER(C_INT), INTENT(OUT), OPTIONAL :: alignment + LOGICAL :: libxsmm_aligned ! C_BOOL (GNU Fortran issue) + INTEGER(C_INT) :: aligned + INTERFACE + SUBROUTINE internal_aligned(is_aligned, location, & + & increment, alignment) BIND(C, NAME="libxsmm_aligned_") + IMPORT :: C_PTR, C_INT, C_BOOL + TYPE(C_PTR), VALUE, INTENT(IN) :: location + INTEGER(C_INT), INTENT(IN) :: increment + INTEGER(C_INT), INTENT(OUT) :: alignment + INTEGER(C_INT), INTENT(OUT) :: is_aligned ! C_BOOL + END SUBROUTINE + END INTERFACE + CALL internal_aligned(aligned, location, increment, alignment) + libxsmm_aligned = 0.NE.aligned + END FUNCTION + END MODULE + + + diff --git a/third_party/libxsmm/include/libxsmm.mod b/third_party/libxsmm/include/libxsmm.mod new file mode 100644 index 0000000000000000000000000000000000000000..2e87ea7ec1156b6f9b2c25231541d20eee590098 Binary files /dev/null and b/third_party/libxsmm/include/libxsmm.mod differ diff --git a/third_party/libxsmm/include/libxsmm_config.h b/third_party/libxsmm/include/libxsmm_config.h new file mode 100644 index 0000000000000000000000000000000000000000..bcde13b160eb59c04303be4dd8e6bd94c4c5ec50 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_config.h @@ -0,0 +1,45 @@ +#ifndef LIBXSMM_CONFIG_H +#define LIBXSMM_CONFIG_H + +#if !defined(LIBXSMM_DEFAULT_CONFIG) && defined(LIBXSMM_SOURCE_H) && !defined(LIBXSMM_CONFIGURED) +# define LIBXSMM_DEFAULT_CONFIG +#endif +#if !defined(LIBXSMM_DEFAULT_CONFIG) && defined(_WIN32) +# define LIBXSMM_DEFAULT_CONFIG +#endif + +#if !defined(LIBXSMM_DEFAULT_CONFIG) && (!defined(LIBXSMM_SOURCE_H) || defined(LIBXSMM_CONFIGURED)) +# include "libxsmm_version.h" + + +#else +# define LIBXSMM_CONFIG_VERSION "" +# define LIBXSMM_CONFIG_BRANCH "" +# define LIBXSMM_CONFIG_VERSION_MAJOR INT_MAX +# define LIBXSMM_CONFIG_VERSION_MINOR INT_MAX +# define LIBXSMM_CONFIG_VERSION_UPDATE INT_MAX +# define LIBXSMM_CONFIG_VERSION_PATCH INT_MAX +# define LIBXSMM_CONFIG_BUILD_DATE INT_MAX +#endif + +#define LIBXSMM_CONFIG_CACHELINE 64 +#define LIBXSMM_CONFIG_ALIGNMENT 64 +#define LIBXSMM_CONFIG_MALLOC 0 +#define LIBXSMM_CONFIG_ILP64 0 +#define LIBXSMM_CONFIG_SYNC 1 +#define LIBXSMM_CONFIG_JIT 1 + +#define LIBXSMM_CONFIG_PREFETCH -1 +#define LIBXSMM_CONFIG_MAX_MNK 262144 +#define LIBXSMM_CONFIG_MAX_DIM 64 +#define LIBXSMM_CONFIG_AVG_DIM 32 +#define LIBXSMM_CONFIG_MAX_M 64 +#define LIBXSMM_CONFIG_MAX_N 64 +#define LIBXSMM_CONFIG_MAX_K 64 +#define LIBXSMM_CONFIG_FLAGS 0 +#define LIBXSMM_CONFIG_ALPHA 1 +#define LIBXSMM_CONFIG_BETA 1 +#define LIBXSMM_CONFIG_WRAP 1 + +#endif + diff --git a/third_party/libxsmm/include/libxsmm_cpuid.h b/third_party/libxsmm/include/libxsmm_cpuid.h new file mode 100644 index 0000000000000000000000000000000000000000..83329b82da4ec01b4fe43f666229c42e0ed460c7 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_cpuid.h @@ -0,0 +1,76 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_CPUID_H +#define LIBXSMM_CPUID_H + +#include "libxsmm_macros.h" + +/** + * Enumerates the available target architectures and instruction + * set extensions as returned by libxsmm_get_target_archid(). + * LIBXSMM_X86_ALLFEAT: pseudo-value enabling all features + * used anywhere in LIBXSMM (never set as an architecture, + * used as an upper bound in comparisons to distinct x86). + */ +#define LIBXSMM_TARGET_ARCH_UNKNOWN 0 +#define LIBXSMM_TARGET_ARCH_GENERIC 1 +#define LIBXSMM_X86_GENERIC 1002 +#define LIBXSMM_X86_SSE3 1003 +#define LIBXSMM_X86_SSE42 1004 +#define LIBXSMM_X86_AVX 1005 +#define LIBXSMM_X86_AVX2 1006 +#define LIBXSMM_X86_AVX512 1007 +#define LIBXSMM_X86_AVX512_MIC 1010 /* KNL */ +#define LIBXSMM_X86_AVX512_KNM 1011 +#define LIBXSMM_X86_AVX512_CORE 1020 /* SKX */ +#define LIBXSMM_X86_AVX512_CLX 1021 +#define LIBXSMM_X86_AVX512_CPX 1022 +#define LIBXSMM_X86_AVX512_SPR 1023 +#define LIBXSMM_X86_ALLFEAT 1999 +#define LIBXSMM_AARCH64_V81 2001 /* Baseline */ +#define LIBXSMM_AARCH64_V82 2002 /* A64FX minus SVE */ +#define LIBXSMM_AARCH64_A64FX 2100 /* SVE */ +#define LIBXSMM_AARCH64_ALLFEAT 2999 + +#if defined(LIBXSMM_PLATFORM_X86) +/** Zero-initialized structure; assumes conservative properties. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_cpuid_info { + int constant_tsc; /** Timer stamp counter is monotonic. */ + int has_context; /** Context switches are permitted. */ +} libxsmm_cpuid_info; +#else +typedef int libxsmm_cpuid_info; +#endif + +/** Returns the target architecture and instruction set extensions. */ +#if defined(__cplusplus) /* note: stay compatible with TF */ +LIBXSMM_API int libxsmm_cpuid_x86(libxsmm_cpuid_info* info = NULL); +LIBXSMM_API int libxsmm_cpuid_arm(libxsmm_cpuid_info* info = NULL); +#else +LIBXSMM_API int libxsmm_cpuid_x86(libxsmm_cpuid_info* info); +LIBXSMM_API int libxsmm_cpuid_arm(libxsmm_cpuid_info* info); +#endif + +/** + * Similar to libxsmm_cpuid_x86, but conceptually not x86-specific. + * The actual code path (as used by LIBXSMM) is determined by + * libxsmm_[get|set]_target_archid/libxsmm_[get|set]_target_arch. + */ +LIBXSMM_API int libxsmm_cpuid(void); + +/** Names the CPU architecture given by CPUID. */ +LIBXSMM_API const char* libxsmm_cpuid_name(int id); + +/** SIMD vector length (VLEN) in 32-bit elements. */ +LIBXSMM_API int libxsmm_cpuid_vlen32(int id); + +#endif /*LIBXSMM_CPUID_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn.h b/third_party/libxsmm/include/libxsmm_dnn.h new file mode 100644 index 0000000000000000000000000000000000000000..c100cbc97593405580ed77e7b6ddc8384de85d5b --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn.h @@ -0,0 +1,132 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_H +#define LIBXSMM_DNN_H + +#include "libxsmm_typedefs.h" + +typedef unsigned int libxsmm_dnn_err_t; + +/** Define error and warning codes */ +#define LIBXSMM_DNN_SUCCESS 0 + +#define LIBXSMM_DNN_WARN_FALLBACK 90000 +#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING 90001 +#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING 90002 +#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING 90003 +#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING 90004 +#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING 90005 +#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING 90006 + +#define LIBXSMM_DNN_ERR_GENERAL 100000 +#define LIBXSMM_DNN_ERR_CREATE_HANDLE 100001 +#define LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE 100002 +#define LIBXSMM_DNN_ERR_INVALID_BLOCKING 100003 +#define LIBXSMM_DNN_ERR_INVALID_HANDLE 100004 +#define LIBXSMM_DNN_ERR_DATA_NOT_BOUND 100005 +#define LIBXSMM_DNN_ERR_CREATE_TENSOR 100006 +#define LIBXSMM_DNN_ERR_INVALID_TENSOR 100007 +#define LIBXSMM_DNN_ERR_MISMATCH_TENSOR 100008 +#define LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR 100009 +#define LIBXSMM_DNN_ERR_INVALID_KIND 100010 +#define LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW 100011 +#define LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT 100012 +#define LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT 100013 +#define LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE 100014 +#define LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS 100015 +#define LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL 100016 +#define LIBXSMM_DNN_ERR_CREATE_LAYOUT 100017 +#define LIBXSMM_DNN_ERR_INVALID_LAYOUT 100018 +#define LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH 100019 +#define LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED 100020 +#define LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE 100021 +#define LIBXSMM_DNN_ERR_INVALID_ALGO 100022 +#define LIBXSMM_DNN_ERR_INVALID_PADDING 100023 +#define LIBXSMM_DNN_ERR_UNKNOWN_BIAS_TYPE 100024 +#define LIBXSMM_DNN_ERR_MISMATCH_BIAS 100025 +#define LIBXSMM_DNN_ERR_INVALID_HANDLE_BIAS 100026 +#define LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL 100027 +#define LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS 100028 +#define LIBXSMM_DNN_ERR_NOT_IMPLEMENTED 100029 +#define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER 100030 +#define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION 100031 +#define LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN 100032 +#define LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING 100033 +#define LIBXSMM_DNN_ERR_INVALID_FORMAT_FC 100034 +#define LIBXSMM_DNN_ERR_INVALID_RNN_TYPE 100035 +#define LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN 100036 +#define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER 100037 +#define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION 100038 +#define LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION 100039 + +/** Kinds of supported compute flavor operations. */ +typedef enum libxsmm_dnn_compute_kind { + /** Forward path */ + LIBXSMM_DNN_COMPUTE_KIND_FWD, + /** Backward path */ + LIBXSMM_DNN_COMPUTE_KIND_BWD, + /** Updated weights. */ + LIBXSMM_DNN_COMPUTE_KIND_UPD, + /** Backward and weightupdate combined, useful for RNNs */ + LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, + /** All routines, need for some init routines. */ + LIBXSMM_DNN_COMPUTE_KIND_ALL +} libxsmm_dnn_compute_kind; + +/** these are some quantization definitions, not sure if we want to + move them into some main part of LIBXSMM */ +/* @TODO check position of these declarations and defines */ +typedef union LIBXSMM_RETARGETABLE libxsmm_intfloat { + unsigned int ui; + float f; +} libxsmm_intfloat; + +/* F32 masking defines */ +#define LIBXSNN_DNN_MASK_SIGN_F32 0x80000000 +#define LIBXSMM_DNN_MASK_EXP_F32 0x7f800000 +#define LIBXSMM_DNN_MASK_MANT_F32 0x007fffff +#define LIBXSMM_DNN_MASK_ABS_F32 0x7fffffff +#define LIBXSMM_DNN_MASK_FULL_F32 0xffffffff +#define LIBXSMM_DNN_MANT_SZ_F32 23 +#define LIBXSMM_DNN_SZ_F32 32 + +/* DFP16 masking defines */ +#define LIBXSMM_DNN_MANT_DFP16 15 +#define LIXSMMM_DNN_RES_DFP16 libxsmm_sexp2_i8i(-(LIBXSMM_DNN_MANT_DFP16)) + +/* Quantization Rounding Defines */ +#define LIBXSMM_DNN_QUANT_NO_ROUND 80000 +#define LIBXSMM_DNN_QUANT_BIAS_ROUND 80001 +#define LIBXSMM_DNN_QUANT_STOCH_ROUND 80002 +#define LIBXSMM_DNN_QUANT_NEAREST_ROUND 80003 +#define LIBXSMM_DNN_QUANT_FPHW_ROUND 80004 + +/** get string of error code */ +LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code); +LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype); +LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype); + +/** some quantization helper functions, + @TODO need to be integrated better for all different ways of quantizations */ +LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ); +LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ); +LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ); +LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ); + +/** some BF16<->FP32 conversion functions + @TODO we need to find a final place for those */ +LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length); +LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len); +LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len); +LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length); + +#endif /*LIBXSMM_DNN_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_convolution.h b/third_party/libxsmm/include/libxsmm_dnn_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..0c956546d0f4def95a2a318e92fa5e9b6a28ae76 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_convolution.h @@ -0,0 +1,93 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_CONVOLUTION_H +#define LIBXSMM_DNN_CONVOLUTION_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" +#include "libxsmm_dnn_fusedbatchnorm.h" + +/** Opaque handles which represents convolutions and LIBXSMM datatypes */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_layer libxsmm_dnn_layer; + +typedef enum libxsmm_dnn_conv_fuse_op { + /* we fuse nothing into convolution */ + LIBXSMM_DNN_CONV_FUSE_NONE = 0 +} libxsmm_dnn_conv_fuse_op; + +/** Type of algorithm used for convolutions. */ +typedef enum libxsmm_dnn_conv_algo { + /** let the library decide */ + LIBXSMM_DNN_CONV_ALGO_AUTO, + /** direct convolution. */ + LIBXSMM_DNN_CONV_ALGO_DIRECT +} libxsmm_dnn_conv_algo; + +/** Structure which describes the input and output of data (DNN). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_conv_desc { + int N; /* number of images in mini-batch */ + int C; /* number of input feature maps */ + int H; /* height of input image */ + int W; /* width of input image */ + int K; /* number of output feature maps */ + int R; /* height of filter kernel */ + int S; /* width of filter kernel */ + int u; /* vertical stride */ + int v; /* horizontal stride */ + int pad_h; /* height of logical rim padding to input + for adjusting output height */ + int pad_w; /* width of logical rim padding to input + for adjusting output width */ + int pad_h_in; /* height of zero-padding in input buffer, + must equal to pad_h for direct conv */ + int pad_w_in; /* width of zero-padding in input buffer, + must equal to pad_w for direct conv */ + int pad_h_out; /* height of zero-padding in output buffer */ + int pad_w_out; /* width of zero-padding in output buffer */ + int threads; /* number of threads to use when running + convolution */ + libxsmm_dnn_datatype datatype_in; /* datatypes used for all input related buffer */ + libxsmm_dnn_datatype datatype_out; /* datatypes used for all output related buffer */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for buffer buffers */ + libxsmm_dnn_tensor_format filter_format; /* format which is for filter buffers */ + libxsmm_dnn_conv_algo algo; /* convolution algorithm used */ + libxsmm_dnn_conv_option options; /* additional options */ + libxsmm_dnn_conv_fuse_op fuse_ops; /* used ops into convolutions */ +} libxsmm_dnn_conv_desc; + +/** Create a layer handle (non-NULL if successful), and pre-build all JIT-code versions. */ +LIBXSMM_API libxsmm_dnn_layer* libxsmm_dnn_create_conv_layer(libxsmm_dnn_conv_desc conv_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_conv_layer(const libxsmm_dnn_layer* handle); + +/** get layout description of buffers and filters from handle */ +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_create_tensor_datalayout(const libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +/** scratch pad management */ +LIBXSMM_API size_t libxsmm_dnn_get_scratch_size(const libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind); + +/** Bind/Release buffers, filters and bias to layer operation */ +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_get_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type); + +/** Run the layer identified by the handle; may use threads internally. */ +LIBXSMM_API void libxsmm_dnn_execute(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_execute_st(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +/** some helper functions for framework integration */ +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_filter(const libxsmm_dnn_layer* handle); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_bf16_filter(const libxsmm_dnn_layer* handle); + +#endif /*LIBXSMM_DNN_CONVOLUTION_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_fullyconnected.h b/third_party/libxsmm/include/libxsmm_dnn_fullyconnected.h new file mode 100644 index 0000000000000000000000000000000000000000..4dd480e490c58373c29da93a8a8c33c465de9ff1 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_fullyconnected.h @@ -0,0 +1,65 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FULLYCONNECTED_H +#define LIBXSMM_DNN_FULLYCONNECTED_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +/** Opaque handles which represents LIBXSMM fullyconnected */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fullyconnected libxsmm_dnn_fullyconnected; + +typedef enum libxsmm_dnn_fullyconnected_fuse_op { + /* the fuse order is: 1. BIAS, 2. Actitvation */ + LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE = 0, + LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS = 1, + LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU = 2, + LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID = 4, + LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU = 3, + LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID = 5 +} libxsmm_dnn_fullyconnected_fuse_op; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fullyconnected_desc { + int N; /* number of images in mini-batch */ + int C; /* number of input feature maps */ + int K; /* number of output feature maps */ + int bn; + int bk; + int bc; + int threads; /* number of threads used */ + int compressed_A; + int sparsity_factor_A; + libxsmm_dnn_datatype datatype_in; /* datatype used for all input related buffers */ + libxsmm_dnn_datatype datatype_out; /* datatype used for all output related buffers */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ + libxsmm_dnn_tensor_format filter_format; /* format which is for filter buffers */ + libxsmm_dnn_fullyconnected_fuse_op fuse_ops; /* fused operations */ +} libxsmm_dnn_fullyconnected_desc; + +LIBXSMM_API libxsmm_dnn_fullyconnected* libxsmm_dnn_create_fullyconnected(libxsmm_dnn_fullyconnected_desc fullyconnected_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API void* libxsmm_dnn_fullyconnected_get_scratch_ptr (const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API size_t libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected* handle, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected* handle); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +#endif /*LIBXSMM_DNN_FULLYCONNECTED_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_fusedbatchnorm.h b/third_party/libxsmm/include/libxsmm_dnn_fusedbatchnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..e94a36a7fa281836c8fbd6a6028b09a8d99bf19f --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_fusedbatchnorm.h @@ -0,0 +1,39 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FUSEDBATCHNORM_H +#define LIBXSMM_DNN_FUSEDBATCHNORM_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +/** Opaque handles which represents LIBXSMM fusedbatchnorm */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedbatchnorm libxsmm_dnn_fusedbatchnorm; + +LIBXSMM_API libxsmm_dnn_fusedbatchnorm* libxsmm_dnn_create_fusedbatchnorm(libxsmm_dnn_fusedbatchnorm_desc fusedbatchnorm_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fusedbatchnorm(const libxsmm_dnn_fusedbatchnorm* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fusedbatchnorm_create_tensor_datalayout(const libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API size_t libxsmm_dnn_fusedbatchnorm_get_scratch_size(const libxsmm_dnn_fusedbatchnorm* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_bind_scratch(libxsmm_dnn_fusedbatchnorm* handle, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_release_scratch(libxsmm_dnn_fusedbatchnorm* handle); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_bind_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fusedbatchnorm_get_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_release_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_execute_st(libxsmm_dnn_fusedbatchnorm* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +#endif /*LIBXSMM_DNN_FUSEDBATCHNORM_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_fusedgroupnorm.h b/third_party/libxsmm/include/libxsmm_dnn_fusedgroupnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..6d1d90a62e55cc186bdce245f0d701e1c40b2831 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_fusedgroupnorm.h @@ -0,0 +1,39 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FUSEDGROUPNORM_H +#define LIBXSMM_DNN_FUSEDGROUPNORM_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +/** Opaque handles which represents LIBXSMM fusedgroupnorm */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedgroupnorm libxsmm_dnn_fusedgroupnorm; + +LIBXSMM_API libxsmm_dnn_fusedgroupnorm* libxsmm_dnn_create_fusedgroupnorm(libxsmm_dnn_fusedgroupnorm_desc fusedgroupnorm_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fusedgroupnorm(const libxsmm_dnn_fusedgroupnorm* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fusedgroupnorm_create_tensor_datalayout(const libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API size_t libxsmm_dnn_fusedgroupnorm_get_scratch_size(const libxsmm_dnn_fusedgroupnorm* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_bind_scratch(libxsmm_dnn_fusedgroupnorm* handle, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_release_scratch(libxsmm_dnn_fusedgroupnorm* handle); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_bind_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fusedgroupnorm_get_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_release_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_execute_st(libxsmm_dnn_fusedgroupnorm* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +#endif /*LIBXSMM_DNN_FUSEDGROUPNORM_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_optimizer.h b/third_party/libxsmm/include/libxsmm_dnn_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..cac46f4078bdbe0a3051b9c2d679e422e79620f3 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_optimizer.h @@ -0,0 +1,55 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_SGD_H +#define LIBXSMM_DNN_SGD_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +/** Opaque handles which represents LIBXSMM optimizer */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_optimizer libxsmm_dnn_optimizer; + +typedef enum libxsmm_dnn_optimizer_type { + LIBXSMM_DNN_OPTIMIZER_SGD = 1 +} libxsmm_dnn_optimizer_type; + + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_optimizer_desc { + int C; /* number of feature maps */ + int K; /* number of feature maps */ + int bc; + int bk; + float learning_rate; /* learning rate */ + int threads; /* number of threads used */ + libxsmm_dnn_optimizer_type opt_type; + libxsmm_dnn_datatype datatype_master; /* datatype used for all input related buffers */ + libxsmm_dnn_datatype datatype; /* datatype used for all input related buffers */ + libxsmm_dnn_tensor_format filter_format; /* format which is for filter buffers */ +} libxsmm_dnn_optimizer_desc; + +LIBXSMM_API libxsmm_dnn_optimizer* libxsmm_dnn_create_optimizer(libxsmm_dnn_optimizer_desc optimizer_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_optimizer(const libxsmm_dnn_optimizer* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_optimizer_create_tensor_datalayout(const libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API void* libxsmm_dnn_optimizer_get_scratch_ptr (const libxsmm_dnn_optimizer* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API size_t libxsmm_dnn_optimizer_get_scratch_size(const libxsmm_dnn_optimizer* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_bind_scratch(libxsmm_dnn_optimizer* handle, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_release_scratch(libxsmm_dnn_optimizer* handle); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_bind_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_optimizer_get_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_release_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_execute_st(libxsmm_dnn_optimizer* handle, /*unsigned*/int start_thread, /*unsigned*/int tid); + +#endif /*LIBXSMM_DNN_SGD_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_pooling.h b/third_party/libxsmm/include/libxsmm_dnn_pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..0a973664b4d3f131d688a23e0c22472dcdef5fbc --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_pooling.h @@ -0,0 +1,65 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_POOLING_H +#define LIBXSMM_DNN_POOLING_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +/** Opaque handles which represents LIBXSMM pooling */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_pooling libxsmm_dnn_pooling; + +typedef enum libxsmm_dnn_pooling_type { + LIBXSMM_DNN_POOLING_MAX = 1, + LIBXSMM_DNN_POOLING_AVG = 2 +} libxsmm_dnn_pooling_type; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_pooling_desc { + int N; /* number of images in mini-batch */ + int C; /* number of input feature maps */ + int H; /* height of input image */ + int W; /* width of input image */ + int R; /* kernel height */ + int S; /* kernel width */ + int u; /* vertical stride */ + int v; /* horizontal stride */ + int pad_h; /* height of logical padding of input buffer */ + int pad_w; /* width of logical padding of input buffer */ + int pad_h_in; /* height of physical zero-padding in input buffer */ + int pad_w_in; /* width of physical zero-padding in input buffer */ + int pad_h_out; /* height of physical zero-padding in output buffer */ + int pad_w_out; /* width of physical zero-padding in output buffer */ + int threads; /* number of threads used */ + libxsmm_dnn_datatype datatype_in; /* datatypes used for all input related buffer */ + libxsmm_dnn_datatype datatype_out; /* datatypes used for all output related buffer */ + libxsmm_dnn_datatype datatype_mask; /* datatypes used for the masks */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ + libxsmm_dnn_pooling_type pooling_type; /* type of pooling operation */ +} libxsmm_dnn_pooling_desc; + +LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling* handle, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +#endif /*LIBXSMM_DNN_POOLING_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_rnncell.h b/third_party/libxsmm/include/libxsmm_dnn_rnncell.h new file mode 100644 index 0000000000000000000000000000000000000000..c3402f9d263238487d17374545710d6c629ef6d7 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_rnncell.h @@ -0,0 +1,79 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_RNNCELL_H +#define LIBXSMM_DNN_RNNCELL_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell libxsmm_dnn_rnncell; + +/** Type of algorithm used for convolutions. */ +typedef enum libxsmm_dnn_rnncell_type { + /** simple RNN cell with ReLU as activation function */ + LIBXSMM_DNN_RNNCELL_RNN_RELU, + /** simple RNN cell with sigmoid as activation function */ + LIBXSMM_DNN_RNNCELL_RNN_SIGMOID, + /** simple RNN cell with tanh as activation function */ + LIBXSMM_DNN_RNNCELL_RNN_TANH, + /** LSTM cell */ + LIBXSMM_DNN_RNNCELL_LSTM, + /** GRU cell */ + LIBXSMM_DNN_RNNCELL_GRU +} libxsmm_dnn_rnncell_type; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell_desc { + int threads; + libxsmm_blasint K; /* number of outputs */ + libxsmm_blasint N; /* size of the minibatch */ + libxsmm_blasint C; /* number of inputs */ + libxsmm_blasint max_T; /* number of time steps */ + libxsmm_blasint bk; + libxsmm_blasint bn; + libxsmm_blasint bc; + int use_fwd_fused_impl; + int fwd_block; + int bwdupd_block; + libxsmm_dnn_rnncell_type cell_type; /* cell type RNN ReLU, RNN Sigmoid, RNN Tanh, LSTM, GRU */ + libxsmm_dnn_datatype datatype_in; /* datatypes used for all input related buffer */ + libxsmm_dnn_datatype datatype_out; /* datatypes used for all output related buffer */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ + libxsmm_dnn_tensor_format filter_format; /* format which is for filter buffers */ +} libxsmm_dnn_rnncell_desc; + +LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status); +LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr (const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind); + +LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status); +LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr (const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T ); +LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status ); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +#endif /*LIBXSMM_DNN_RNNCELL_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_softmaxloss.h b/third_party/libxsmm/include/libxsmm_dnn_softmaxloss.h new file mode 100644 index 0000000000000000000000000000000000000000..0e9b9f5552a51e18f2a387d65c2a361779a1965b --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_softmaxloss.h @@ -0,0 +1,51 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_SOFTMAXLOSS_H +#define LIBXSMM_DNN_SOFTMAXLOSS_H + +#include "libxsmm_dnn.h" +#include "libxsmm_dnn_tensor.h" + +/** Opaque handles which represents LIBXSMM softmaxloss */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_softmaxloss libxsmm_dnn_softmaxloss; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_softmaxloss_desc { + int N; /* number of images in mini-batch */ + int C; /* number of input feature maps */ + int bn; /* requested N blocking for NCNC format */ + int bc; /* requested C blocking for NCNC format */ + float loss_weight; /* loss weight */ + int threads; /* number of threads used */ + libxsmm_dnn_datatype datatype; /* datatype used for all buffers */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ +} libxsmm_dnn_softmaxloss_desc; + +LIBXSMM_API libxsmm_dnn_softmaxloss* libxsmm_dnn_create_softmaxloss(libxsmm_dnn_softmaxloss_desc softmaxloss_desc, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_softmaxloss(const libxsmm_dnn_softmaxloss* handle); + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_softmaxloss_create_tensor_datalayout(const libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); + +LIBXSMM_API void* libxsmm_dnn_softmaxloss_get_scratch_ptr (const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API size_t libxsmm_dnn_softmaxloss_get_scratch_size(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_scratch(libxsmm_dnn_softmaxloss* handle, const void* scratch); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_scratch(libxsmm_dnn_softmaxloss* handle); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_softmaxloss_get_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type); + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_execute_st(libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid); + +LIBXSMM_API float libxsmm_dnn_softmaxloss_get_loss(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status); + +#endif /*LIBXSMM_DNN_SOFTMAXLOSS_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_dnn_tensor.h b/third_party/libxsmm/include/libxsmm_dnn_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..c33185dfbf45e366be7abbc4efac79246910a04b --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_dnn_tensor.h @@ -0,0 +1,199 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_TENSOR_H +#define LIBXSMM_DNN_TENSOR_H + +#include "libxsmm_typedefs.h" +#include "libxsmm_dnn.h" + +/** Opaque handles which represents convolutions and LIBXSMM datatypes */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor libxsmm_dnn_tensor; + +typedef enum libxsmm_dnn_tensor_dimtype { + /** Mini-batch */ + LIBXSMM_DNN_TENSOR_DIMTYPE_N, + /** Image Height */ + LIBXSMM_DNN_TENSOR_DIMTYPE_H, + /** Image Width */ + LIBXSMM_DNN_TENSOR_DIMTYPE_W, + /** channels or input channels */ + LIBXSMM_DNN_TENSOR_DIMTYPE_C, + /** output channels */ + LIBXSMM_DNN_TENSOR_DIMTYPE_K, + /** kernel height */ + LIBXSMM_DNN_TENSOR_DIMTYPE_R, + /** kernel width */ + LIBXSMM_DNN_TENSOR_DIMTYPE_S, + /** sequence lenth counter */ + LIBXSMM_DNN_TENSOR_DIMTYPE_T, + /** channle group counter */ + LIBXSMM_DNN_TENSOR_DIMTYPE_G, + /** general counter */ + LIBXSMM_DNN_TENSOR_DIMTYPE_X +} libxsmm_dnn_tensor_dimtype; + +/** types of different buffers */ +typedef enum libxsmm_dnn_tensor_type { + /** regular input buffer */ + LIBXSMM_DNN_REGULAR_INPUT, + /** regular input buffer */ + LIBXSMM_DNN_REGULAR_INPUT_ADD, + /** regular input buffer, transpose */ + LIBXSMM_DNN_REGULAR_INPUT_TRANS, + /** gradient input buffer */ + LIBXSMM_DNN_GRADIENT_INPUT, + /** gradient input buffer */ + LIBXSMM_DNN_GRADIENT_INPUT_ADD, + /** regular output buffer */ + LIBXSMM_DNN_REGULAR_OUTPUT, + /** gradient output buffer */ + LIBXSMM_DNN_GRADIENT_OUTPUT, + /** general input type */ + LIBXSMM_DNN_INPUT, + /** general output type */ + LIBXSMM_DNN_OUTPUT, + /** general activation type */ + LIBXSMM_DNN_ACTIVATION, + /* regular filter */ + LIBXSMM_DNN_REGULAR_FILTER, + /* regular filter */ + LIBXSMM_DNN_REGULAR_FILTER_TRANS, + /* gradient filter */ + LIBXSMM_DNN_GRADIENT_FILTER, + /* master filter */ + LIBXSMM_DNN_MASTER_FILTER, + /** general filter type */ + LIBXSMM_DNN_FILTER, + /* regular bias */ + LIBXSMM_DNN_REGULAR_CHANNEL_BIAS, + /* gradient bias */ + LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS, + /* bias */ + LIBXSMM_DNN_CHANNEL_BIAS, + /* regular beta */ + LIBXSMM_DNN_REGULAR_CHANNEL_BETA, + /* gradient beta */ + LIBXSMM_DNN_GRADIENT_CHANNEL_BETA, + /* beta */ + LIBXSMM_DNN_CHANNEL_BETA, + /* regular gamma */ + LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA, + /* gradient gamma */ + LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA, + /* Gamma */ + LIBXSMM_DNN_CHANNEL_GAMMA, + /* regular beta */ + LIBXSMM_DNN_CHANNEL_EXPECTVAL, + /* regular beta */ + LIBXSMM_DNN_CHANNEL_RCPSTDDEV, + /* variance */ + LIBXSMM_DNN_CHANNEL_VARIANCE, + /** general bias type */ + LIBXSMM_DNN_CHANNEL_SCALAR, + /** Labels */ + LIBXSMM_DNN_LABEL, + /** batch stats */ + LIBXSMM_DNN_BATCH_STATS, + LIBXSMM_DNN_MAX_STATS_FWD, + LIBXSMM_DNN_MAX_STATS_BWD, + LIBXSMM_DNN_MAX_STATS_UPD, + /** pooling mask */ + LIBXSMM_DNN_POOLING_MASK, + /** ReLU mask */ + LIBXSMM_DNN_RELU_MASK, + /** general type, if needed might cause API issues in copy in/out API */ + LIBXSMM_DNN_TENSOR, + + /** regular input buffer */ + LIBXSMM_DNN_RNN_REGULAR_INPUT, + /** regular previous cell state buffer */ + LIBXSMM_DNN_RNN_REGULAR_CS_PREV, + /** regular previous hidden state buffer */ + LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV, + /** regular weight (LSTM: wi, wc, wf, wo) */ + LIBXSMM_DNN_RNN_REGULAR_WEIGHT, + /** regular recurrent weight (LSTM: ri, rc, rf, ro) */ + LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT, + /** regular weight (LSTM: wi, wc, wf, wo) */ + LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS, + /** regular recurrent weight (LSTM: ri, rc, rf, ro) */ + LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS, + /** regular bias (LSTM: bi, bc, bf, bo) */ + LIBXSMM_DNN_RNN_REGULAR_BIAS, + /** regular output cell state buffer */ + LIBXSMM_DNN_RNN_REGULAR_CS, + /** regular hidden state buffer */ + LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE, + /** gradient input buffer */ + LIBXSMM_DNN_RNN_GRADIENT_INPUT, + /** gradient previous cell state buffer */ + LIBXSMM_DNN_RNN_GRADIENT_CS_PREV, + /** gradient previous hidden state buffer */ + LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV, + /** gradient weight */ + LIBXSMM_DNN_RNN_GRADIENT_WEIGHT, + /** gradient recurrent weight */ + LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT, + /** gradient bias */ + LIBXSMM_DNN_RNN_GRADIENT_BIAS, + /** gradient output cell state buffer */ + LIBXSMM_DNN_RNN_GRADIENT_CS, + /** gradient hidden state buffer */ + LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE, + /** internal i buffer */ + LIBXSMM_DNN_RNN_INTERNAL_I, + /** internal f buffer */ + LIBXSMM_DNN_RNN_INTERNAL_F, + /** internal o buffer */ + LIBXSMM_DNN_RNN_INTERNAL_O, + /** internal ci buffer */ + LIBXSMM_DNN_RNN_INTERNAL_CI, + /** internal co buffer */ + LIBXSMM_DNN_RNN_INTERNAL_CO +} libxsmm_dnn_tensor_type; + +/** layout descriptor to allow external data handling + outside of LIBXSMM */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor_datalayout { + libxsmm_dnn_tensor_dimtype* dim_type; + unsigned int* dim_size; + unsigned int num_dims; + libxsmm_dnn_tensor_format format; /* format of activation buffer */ + libxsmm_dnn_datatype datatype; /* data type */ + libxsmm_dnn_tensor_type tensor_type; /* tensor type */ +} libxsmm_dnn_tensor_datalayout; + +/** tensorlayout handling */ +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout); +LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status); +LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status); +LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status); + +/** Create and manage buffers, filters and bias (non-NULL if successful) */ +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char exp, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data); +LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status); +LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor); + +/** + * Copy-in/out from a plain format such [N][C][H][W] or [N][H][W][C] + */ +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format); +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format); + +#endif /*LIBXSMM_DNN_TENSOR_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_frontend.h b/third_party/libxsmm/include/libxsmm_frontend.h new file mode 100644 index 0000000000000000000000000000000000000000..afb984985ed187c47f46ad99e06f98bb250ee5b7 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_frontend.h @@ -0,0 +1,590 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_FRONTEND_H +#define LIBXSMM_FRONTEND_H + +#include "libxsmm_typedefs.h" + +/** Helper macros for eliding prefetch address calculations depending on prefetch scheme. */ +#if !defined(_WIN32) && !defined(__CYGWIN__) /* TODO: fully support calling convention */ +#if 0 != ((LIBXSMM_PREFETCH) & 2/*AL2*/) \ + || 0 != ((LIBXSMM_PREFETCH) & 8/*AL2_AHEAD*/) +# define LIBXSMM_GEMM_PREFETCH_A(EXPR) (EXPR) +#endif +#if 0 != ((LIBXSMM_PREFETCH) & 4/*BL2_VIA_C*/) \ + || 0 != ((LIBXSMM_PREFETCH) & 16/*BL1*/) +# define LIBXSMM_GEMM_PREFETCH_B(EXPR) (EXPR) +#endif +#endif +/** Secondary helper macros derived from the above group. */ +#if defined(LIBXSMM_GEMM_PREFETCH_A) +# define LIBXSMM_NOPREFETCH_A(EXPR) +#else +# define LIBXSMM_NOPREFETCH_A(EXPR) EXPR +# define LIBXSMM_GEMM_PREFETCH_A(EXPR) 0 +#endif +#if defined(LIBXSMM_GEMM_PREFETCH_B) +# define LIBXSMM_NOPREFETCH_B(EXPR) +#else +# define LIBXSMM_NOPREFETCH_B(EXPR) EXPR +# define LIBXSMM_GEMM_PREFETCH_B(EXPR) 0 +#endif +#if defined(LIBXSMM_GEMM_PREFETCH_C) +# define LIBXSMM_NOPREFETCH_C(EXPR) +#else +# define LIBXSMM_NOPREFETCH_C(EXPR) EXPR +# define LIBXSMM_GEMM_PREFETCH_C(EXPR) 0 +#endif + +/** MKL_DIRECT_CALL requires to include the MKL interface. */ +#if (defined(MKL_DIRECT_CALL_SEQ) || defined(MKL_DIRECT_CALL) || \ + (defined(__MKL) && !defined(LIBXSMM_BUILD) && \ + (!defined(__BLAS) || (0 != __BLAS)))) +# if (0 != LIBXSMM_ILP64 && !defined(MKL_ILP64)) +# error "Inconsistent ILP64 configuration detected!" +# endif +# if defined(LIBXSMM_OFFLOAD_BUILD) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +# include +# pragma offload_attribute(pop) +# else +# include +# endif +#endif +/** __INTEL_MKL__ is needed later to fix some NOTHROW issue. */ +#if defined(__MKL) && !defined(__INTEL_MKL__) && defined(NOTHROW) +# if defined(LIBXSMM_OFFLOAD_BUILD) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +# include +# pragma offload_attribute(pop) +# else +# include +# endif +#endif + +/** Unfortunately calculation of INTEL_MKL_VERSION is not stable over time. */ +#if defined(__INTEL_MKL__) && defined(__INTEL_MKL_MINOR__) && defined(__INTEL_MKL_UPDATE__) +# define LIBXSMM_MKL_VERSION3 LIBXSMM_VERSION3(__INTEL_MKL__, __INTEL_MKL_MINOR__, __INTEL_MKL_UPDATE__) +#endif + +/** Automatically select a prefetch-strategy (libxsmm_get_gemm_xprefetch, etc.). */ +#define LIBXSMM_PREFETCH_AUTO -1 + +/** Append "_omp" postfix to the given symbol. */ +#define LIBXSMM_USEOMP(FUNCTION) LIBXSMM_CONCATENATE(FUNCTION, _omp) + +/** Helper macro for BLAS-style prefixes. */ +#define LIBXSMM_TPREFIX_NAME(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_TPREFIX_, TYPE) +#define LIBXSMM_TPREFIX(TYPE, FUNCTION) LIBXSMM_CONCATENATE(LIBXSMM_TPREFIX_NAME(TYPE), FUNCTION) +#define LIBXSMM_TPREFIX_doubledouble d +#define LIBXSMM_TPREFIX_floatfloat s +#define LIBXSMM_TPREFIX_shortfloat ws +#define LIBXSMM_TPREFIX_shortint wi +#define LIBXSMM_TPREFIX_libxsmm_bfloat16float bs +/** Defaults if only the input type is specified. */ +#define LIBXSMM_TPREFIX_double LIBXSMM_TPREFIX_doubledouble +#define LIBXSMM_TPREFIX_float LIBXSMM_TPREFIX_floatfloat +#define LIBXSMM_TPREFIX_short LIBXSMM_TPREFIX_shortint + +#define LIBXSMM_GEMM_XFLAGS(ITYPE, OTYPE) LIBXSMM_CONCATENATE(LIBXSMM_GEMM_XFLAGS_, ITYPE) /* ignore OTYPE for now */ +#define LIBXSMM_GEMM_XFLAGS_double 0 +#define LIBXSMM_GEMM_XFLAGS_float 0 +#define LIBXSMM_GEMM_XFLAGS_libxsmm_bfloat16 LIBXSMM_GEMM_FLAG_VNNI_A +#define LIBXSMM_GEMM_XFLAGS_int 0 +#define LIBXSMM_GEMM_XFLAGS_short 0 + +/** Construct symbol name from a given real type name (float, double and short). */ +#define LIBXSMM_BLAS_FNTYPE(TYPE, KIND) LIBXSMM_CONCATENATE3(libxsmm_, LIBXSMM_TPREFIX(TYPE, KIND), _function) +#define LIBXSMM_MMFUNCTION_TYPE(TYPE) LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(TYPE, mmfunction)) +#define LIBXSMM_MMDISPATCH_SYMBOL(TYPE) LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(TYPE, mmdispatch)) +#define LIBXSMM_XBLAS_SYMBOL(TYPE) LIBXSMM_CONCATENATE(libxsmm_blas_, LIBXSMM_TPREFIX(TYPE, gemm)) +#define LIBXSMM_XGEMM_SYMBOL(TYPE) LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(TYPE, gemm)) +#define LIBXSMM_YGEMM_SYMBOL(TYPE) LIBXSMM_USEOMP(LIBXSMM_XGEMM_SYMBOL(TYPE)) +#define LIBXSMM_BLAS_SYMBOL(TYPE, KIND) LIBXSMM_FSYMBOL(LIBXSMM_TPREFIX(TYPE, KIND)) +#define LIBXSMM_CBLAS_SYMBOL LIBXSMM_TPREFIX + +#define LIBXSMM_BLAS_DECL(TYPE, KIND, DECL) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_, LIBXSMM_TPREFIX(TYPE, KIND))(DECL) +#if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL) +# define LIBXSMM_BLAS_dgemm(DECL) DECL; +# define LIBXSMM_BLAS_sgemm(DECL) DECL; +# define LIBXSMM_BLAS_dgemv(DECL) DECL; +# define LIBXSMM_BLAS_sgemv(DECL) DECL; +#else +# define LIBXSMM_BLAS_dgemm +# define LIBXSMM_BLAS_sgemm +# define LIBXSMM_BLAS_dgemv +# define LIBXSMM_BLAS_sgemv +#endif + +/* Construct prefix names, function type or dispatch function from given input and output types. */ +#define LIBXSMM_MMFUNCTION_TYPE2(ITYPE, OTYPE) LIBXSMM_MMFUNCTION_TYPE(LIBXSMM_CONCATENATE(ITYPE, OTYPE)) +#define LIBXSMM_MMDISPATCH_SYMBOL2(ITYPE, OTYPE) LIBXSMM_MMDISPATCH_SYMBOL(LIBXSMM_CONCATENATE(ITYPE, OTYPE)) +#define LIBXSMM_TPREFIX_NAME2(ITYPE, OTYPE) LIBXSMM_TPREFIX_NAME(LIBXSMM_CONCATENATE(ITYPE, OTYPE)) +#define LIBXSMM_TPREFIX2(ITYPE, OTYPE, FUNCTION) LIBXSMM_TPREFIX(LIBXSMM_CONCATENATE(ITYPE, OTYPE), FUNCTION) + +/** Helper macro for comparing selected types. */ +#define LIBXSMM_EQUAL(T1, T2) LIBXSMM_CONCATENATE3(LIBXSMM_EQUAL_, T1, T2) +#define LIBXSMM_EQUAL_floatfloat 1 +#define LIBXSMM_EQUAL_doubledouble 1 +#define LIBXSMM_EQUAL_floatdouble 0 +#define LIBXSMM_EQUAL_doublefloat 0 +#define LIBXSMM_EQUAL_shortdouble 0 +#define LIBXSMM_EQUAL_shortfloat 0 + +#if defined(LIBXSMM_BLAS_CONST) +# undef LIBXSMM_BLAS_CONST +# define LIBXSMM_BLAS_CONST const +#elif defined(OPENBLAS_CONST) +# define LIBXSMM_BLAS_CONST OPENBLAS_CONST +#elif defined(LIBXSMM_BLAS_NONCONST) || defined(__OPENBLAS) || defined(__OPENBLAS77) +# define LIBXSMM_BLAS_CONST +#else +# define LIBXSMM_BLAS_CONST const +#endif + +#if !defined(LIBXSMM_NO_BLAS) +# if (!defined(__BLAS) || (0 != __BLAS)) +# define LIBXSMM_NO_BLAS 0 +# define LIBXSMM_BLAS 1 +# else +# define LIBXSMM_NO_BLAS 1 +# define LIBXSMM_BLAS 0 +# endif +#endif + +#if defined(__BLAS) && (1 == __BLAS) +# if defined(__OPENBLAS) + LIBXSMM_EXTERN void openblas_set_num_threads(int num_threads); +# define LIBXSMM_BLAS_INIT openblas_set_num_threads(1); +# endif +#endif +#if !defined(LIBXSMM_BLAS_INIT) +# define LIBXSMM_BLAS_INIT +#endif + +#if defined(LIBXSMM_BUILD) +# if defined(LIBXSMM_BUILD_EXT) && !defined(__STATIC) +# define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_APIEXT +# elif defined(LIBXSMM_NO_BLAS) && (1 == LIBXSMM_NO_BLAS) +# define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_API +# endif +#endif +#if !defined(LIBXSMM_BLAS_SYMBOL_VISIBILITY) +# define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_EXTERN LIBXSMM_VISIBILITY_IMPORT LIBXSMM_RETARGETABLE +#endif + +#if defined(NOTHROW) +# define LIBXSMM_BLAS_NOTHROW NOTHROW +#else +# define LIBXSMM_BLAS_NOTHROW LIBXSMM_NOEXCEPT +#endif +#define LIBXSMM_BLAS_NOEXCEPT(KIND) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_NOEXCEPT_, KIND) +#if defined(LIBXSMM_MKL_VERSION3) && (LIBXSMM_VERSION3(2020, 0, 2) <= LIBXSMM_MKL_VERSION3) +# define LIBXSMM_BLAS_NOEXCEPT_gemm_batch LIBXSMM_BLAS_NOTHROW +#else +# define LIBXSMM_BLAS_NOEXCEPT_gemm_batch +#endif +#define LIBXSMM_BLAS_NOEXCEPT_gemm LIBXSMM_BLAS_NOTHROW +#define LIBXSMM_BLAS_NOEXCEPT_gemv LIBXSMM_BLAS_NOTHROW + +#define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemm_batch(CONST_STAR, STAR, TYPE) char CONST_STAR, char CONST_STAR, \ + libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, \ + TYPE CONST_STAR, TYPE CONST_STAR STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR STAR, libxsmm_blasint CONST_STAR, \ + TYPE CONST_STAR, TYPE STAR STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR +#define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemm(CONST_STAR, STAR, TYPE) char CONST_STAR, char CONST_STAR, \ + libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR, TYPE CONST_STAR, libxsmm_blasint CONST_STAR, \ + TYPE CONST_STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR, TYPE STAR, libxsmm_blasint CONST_STAR +#define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemv(CONST_STAR, STAR, TYPE) char CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, \ + TYPE CONST_STAR, TYPE CONST_STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR, libxsmm_blasint CONST_STAR, \ + TYPE CONST_STAR, TYPE STAR, libxsmm_blasint CONST_STAR +#define LIBXSMM_BLAS_SYMBOL_SIGNATURE(CONST_STAR, STAR, TYPE, KIND) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_SYMBOL_SIGNATURE_, KIND)(CONST_STAR, STAR, TYPE) +#define LIBXSMM_BLAS_SYMBOL_FDECL(CONST_STAR, STAR, TYPE, KIND) LIBXSMM_BLAS_SYMBOL_VISIBILITY \ + void LIBXSMM_BLAS_SYMBOL(TYPE, KIND)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(CONST_STAR, STAR, TYPE, KIND)) LIBXSMM_BLAS_NOEXCEPT(KIND) +#define LIBXSMM_BLAS_SYMBOL_CDECL(CONST_STAR, STAR, TYPE, KIND) LIBXSMM_BLAS_SYMBOL_VISIBILITY \ + void LIBXSMM_CBLAS_SYMBOL(TYPE, KIND)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(CONST_STAR, STAR, TYPE, KIND)) LIBXSMM_BLAS_NOEXCEPT(KIND) + +#if (0 != LIBXSMM_BLAS) /* BLAS available */ +# define LIBXSMM_BLAS_SYMBOL_DECL(TYPE, KIND) LIBXSMM_BLAS_DECL(TYPE, KIND, LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, TYPE, KIND)) +#else +# define LIBXSMM_BLAS_SYMBOL_DECL(TYPE, KIND) +#endif + +/** Helper macro consolidating the transpose requests into a set of flags. */ +#define LIBXSMM_GEMM_FLAGS(TRANSA, TRANSB) /* check for N/n rather than T/t since C/c is also valid! */ \ + ((('n' == (TRANSA) || *"N" == (TRANSA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \ + | (('n' == (TRANSB) || *"N" == (TRANSB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B)) + +/** Helper macro consolidating CBLAS transpose requests into a set of flags. */ +#define LIBXSMM_GEMM_CFLAGS(TRANSA, TRANSB) /* check for N/n rather than T/t since C/c is also valid! */ \ + ((CblasNoTrans == (TRANSA) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \ + | (CblasNoTrans == (TRANSB) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B)) + +/** Helper macro consolidating the transpose requests into a set of flags. */ +#define LIBXSMM_GEMM_VNNI_FLAGS(TRANSA, TRANSB, VNNIA, VNNIB) /* check for N/n rather than T/t since C/c is also valid! */ \ + ((('n' == (TRANSA) || *"N" == (TRANSA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \ + | (('n' == (TRANSB) || *"N" == (TRANSB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B) \ + | (('n' == (VNNIA) || *"N" == (VNNIA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_VNNI_A) \ + | (('n' == (VNNIB) || *"N" == (VNNIB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_VNNI_B)) + +/** Helper macro allowing NULL-requests (transposes) supplied by some default. */ +#define LIBXSMM_GEMM_PFLAGS(TRANSA, TRANSB, DEFAULT) LIBXSMM_GEMM_FLAGS( \ + NULL != ((const void*)(TRANSA)) ? (*(const char*)(TRANSA)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & (DEFAULT)) ? 'n' : 't'), \ + NULL != ((const void*)(TRANSB)) ? (*(const char*)(TRANSB)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & (DEFAULT)) ? 'n' : 't')) \ + | (~(LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B) & (DEFAULT)) + +/** Inlinable GEMM exercising the compiler's code generation (macro template). TODO: only NN is supported and SP/DP matrices. */ +#define LIBXSMM_INLINE_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \ + /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \ + const char libxsmm_inline_xgemm_transa_ = (char)(NULL != ((void*)(TRANSA)) ? (*(const char*)(TRANSA)) : \ + (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & LIBXSMM_FLAGS) ? 'n' : 't')); \ + const char libxsmm_inline_xgemm_transb_ = (char)(NULL != ((void*)(TRANSB)) ? (*(const char*)(TRANSB)) : \ + (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & LIBXSMM_FLAGS) ? 'n' : 't')); \ + const libxsmm_blasint libxsmm_inline_xgemm_m_ = *(const libxsmm_blasint*)(M); /* must be specified */ \ + const libxsmm_blasint libxsmm_inline_xgemm_k_ = (NULL != ((void*)(K)) ? (*(const libxsmm_blasint*)(K)) : libxsmm_inline_xgemm_m_); \ + const libxsmm_blasint libxsmm_inline_xgemm_n_ = (NULL != ((void*)(N)) ? (*(const libxsmm_blasint*)(N)) : libxsmm_inline_xgemm_k_); \ + const libxsmm_blasint libxsmm_inline_xgemm_lda_ = (NULL != ((void*)(LDA)) ? (*(const libxsmm_blasint*)(LDA)) : \ + (('n' == libxsmm_inline_xgemm_transa_ || *"N" == libxsmm_inline_xgemm_transa_) ? libxsmm_inline_xgemm_m_ : libxsmm_inline_xgemm_k_)); \ + const libxsmm_blasint libxsmm_inline_xgemm_ldb_ = (NULL != ((void*)(LDB)) ? (*(const libxsmm_blasint*)(LDB)) : \ + (('n' == libxsmm_inline_xgemm_transb_ || *"N" == libxsmm_inline_xgemm_transb_) ? libxsmm_inline_xgemm_k_ : libxsmm_inline_xgemm_n_)); \ + const libxsmm_blasint libxsmm_inline_xgemm_ldc_ = (NULL != ((void*)(LDC)) ? (*(const libxsmm_blasint*)(LDC)) : libxsmm_inline_xgemm_m_); \ + const OTYPE libxsmm_inline_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ + const OTYPE libxsmm_inline_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ + libxsmm_blasint libxsmm_inline_xgemm_ni_, libxsmm_inline_xgemm_mi_ = 0, libxsmm_inline_xgemm_ki_; /* loop induction variables */ \ + LIBXSMM_ASSERT('n' == libxsmm_inline_xgemm_transa_ || *"N" == libxsmm_inline_xgemm_transa_); \ + LIBXSMM_ASSERT('n' == libxsmm_inline_xgemm_transb_ || *"N" == libxsmm_inline_xgemm_transb_); \ + LIBXSMM_PRAGMA_SIMD \ + for (libxsmm_inline_xgemm_mi_ = 0; libxsmm_inline_xgemm_mi_ < libxsmm_inline_xgemm_m_; ++libxsmm_inline_xgemm_mi_) { \ + LIBXSMM_PRAGMA_LOOP_COUNT(1, LIBXSMM_CONFIG_MAX_DIM, LIBXSMM_CONFIG_AVG_DIM) \ + for (libxsmm_inline_xgemm_ki_ = 0; libxsmm_inline_xgemm_ki_ < libxsmm_inline_xgemm_k_; ++libxsmm_inline_xgemm_ki_) { \ + LIBXSMM_PRAGMA_UNROLL \ + for (libxsmm_inline_xgemm_ni_ = 0; libxsmm_inline_xgemm_ni_ < libxsmm_inline_xgemm_n_; ++libxsmm_inline_xgemm_ni_) { \ + ((OTYPE*)(C))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldc_+libxsmm_inline_xgemm_mi_] \ + = ((const ITYPE*)(B))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldb_+libxsmm_inline_xgemm_ki_] * \ + (((const ITYPE*)(A))[libxsmm_inline_xgemm_ki_*libxsmm_inline_xgemm_lda_+libxsmm_inline_xgemm_mi_] * libxsmm_inline_xgemm_alpha_) \ + + ((const OTYPE*)(C))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldc_+libxsmm_inline_xgemm_mi_] * libxsmm_inline_xgemm_beta_; \ + } \ + } \ + } \ +} + +#if (defined(LIBXSMM_INIT) || defined(LIBXSMM_CTOR)) +# undef LIBXSMM_INIT +# define LIBXSMM_INIT LIBXSMM_ASSERT_MSG(1 < libxsmm_ninit, "LIBXSMM is not initialized"); +# define LIBXSMM_INIT_COMPLETED +#else +# define LIBXSMM_INIT if (2 > libxsmm_ninit) libxsmm_init(); +#endif + +/** Map to appropriate BLAS function (or fallback). The mapping is used, e.g., inside of LIBXSMM_BLAS_XGEMM. */ +#define LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, FUNCTION) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_FUNCTION_, LIBXSMM_TPREFIX2(ITYPE, OTYPE, FUNCTION)) +#if (0 != LIBXSMM_BLAS) /* Helper macro to eventually (if defined) call libxsmm_init */ +# if defined(LIBXSMM_INIT_COMPLETED) +# define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_original_dgemm_batch_function +# define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_original_sgemm_batch_function +# define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_original_dgemm_function +# define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_original_sgemm_function +# define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_original_dgemv_function +# define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_original_sgemv_function +# else +# define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_original_dgemm_batch() +# define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_original_sgemm_batch() +# define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_original_dgemm() +# define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_original_sgemm() +# define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_original_dgemv() +# define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_original_sgemv() +# endif +#else /* no BLAS */ +# define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_blas_error("dgemm_batch") +# define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_blas_error("sgemm_batch") +# define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_blas_error("dgemm") +# define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_blas_error("sgemm") +# define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_blas_error("dgemv") +# define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_blas_error("sgemv") +#endif +/** Low-precision (BLAS-like) function symbols. */ +#define LIBXSMM_BLAS_FUNCTION_wigemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + LIBXSMM_INLINE_XGEMM(short, int, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) +#define LIBXSMM_BLAS_FUNCTION_bsgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + LIBXSMM_INLINE_XGEMM(libxsmm_bfloat16, float, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) + +/** Short-cut macros to construct desired BLAS function symbol. */ +#define LIBXSMM_BLAS_FUNCTION1(TYPE, FUNCTION) LIBXSMM_BLAS_FUNCTION(TYPE, TYPE, FUNCTION) +#define LIBXSMM_GEMM_BATCH_SYMBOL(TYPE) LIBXSMM_BLAS_FUNCTION1(TYPE, gemm_batch) +#define LIBXSMM_GEMM_SYMBOL(TYPE) LIBXSMM_BLAS_FUNCTION1(TYPE, gemm) +#define LIBXSMM_GEMV_SYMBOL(TYPE) LIBXSMM_BLAS_FUNCTION1(TYPE, gemv) + +/** BLAS-based GEMM supplied by the linked LAPACK/BLAS library (macro template). */ +#define LIBXSMM_BLAS_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \ + /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \ + const char libxsmm_blas_xgemm_transa_ = (char)(NULL != ((void*)(TRANSA)) ? (*(const char*)(TRANSA)) : \ + (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & LIBXSMM_FLAGS) ? 'n' : 't')); \ + const char libxsmm_blas_xgemm_transb_ = (char)(NULL != ((void*)(TRANSB)) ? (*(const char*)(TRANSB)) : \ + (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & LIBXSMM_FLAGS) ? 'n' : 't')); \ + const libxsmm_blasint *const libxsmm_blas_xgemm_k_ = (NULL != ((void*)(K)) ? (K) : (M)); \ + const libxsmm_blasint *const libxsmm_blas_xgemm_n_ = (NULL != ((void*)(N)) ? (N) : libxsmm_blas_xgemm_k_); \ + const libxsmm_blasint libxsmm_blas_xgemm_lda_ = LIBXSMM_MAX(NULL != ((void*)(LDA)) ? *(LDA) : \ + *(('n' == libxsmm_blas_xgemm_transa_ || *"N" == libxsmm_blas_xgemm_transa_) ? (M) : libxsmm_blas_xgemm_k_), 1); \ + const libxsmm_blasint libxsmm_blas_xgemm_ldb_ = LIBXSMM_MAX(NULL != ((void*)(LDB)) ? *(LDB) : \ + *(('n' == libxsmm_blas_xgemm_transb_ || *"N" == libxsmm_blas_xgemm_transb_) ? libxsmm_blas_xgemm_k_ : libxsmm_blas_xgemm_n_), 1); \ + const libxsmm_blasint libxsmm_blas_xgemm_ldc_ = LIBXSMM_MAX(NULL != ((void*)(LDC)) ? *(LDC) : *(M), 1); \ + const OTYPE libxsmm_blas_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ + const OTYPE libxsmm_blas_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ + LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(&libxsmm_blas_xgemm_transa_, &libxsmm_blas_xgemm_transb_, \ + M, libxsmm_blas_xgemm_n_, libxsmm_blas_xgemm_k_, \ + &libxsmm_blas_xgemm_alpha_, (const ITYPE*)(A), &libxsmm_blas_xgemm_lda_, \ + (const ITYPE*)(B), &libxsmm_blas_xgemm_ldb_, \ + &libxsmm_blas_xgemm_beta_, (ITYPE*)(C), &libxsmm_blas_xgemm_ldc_); \ +} + +/** Helper macros for calling a dispatched function in a row/column-major aware fashion. */ +#define LIBXSMM_MMCALL_ABC(FN, A, B, C) \ + LIBXSMM_ASSERT(FN); FN(A, B, C) +#define LIBXSMM_MMCALL_PRF(FN, A, B, C, PA, PB, PC) { \ + LIBXSMM_NOPREFETCH_A(LIBXSMM_UNUSED(PA)); \ + LIBXSMM_NOPREFETCH_B(LIBXSMM_UNUSED(PB)); \ + LIBXSMM_NOPREFETCH_C(LIBXSMM_UNUSED(PC)); \ + LIBXSMM_ASSERT(FN); FN(A, B, C, \ + LIBXSMM_GEMM_PREFETCH_A(PA), \ + LIBXSMM_GEMM_PREFETCH_B(PB), \ + LIBXSMM_GEMM_PREFETCH_C(PC)); \ +} + +#if (0/*LIBXSMM_GEMM_PREFETCH_NONE*/ == LIBXSMM_PREFETCH) +# define LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, LDA, LDB, LDC) \ + LIBXSMM_MMCALL_ABC(FN, A, B, C) +#else +# define LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, LDA, LDB, LDC) \ + LIBXSMM_MMCALL_PRF(FN, A, B, C, (A) + ((size_t)LDA) * (K), (B) + ((size_t)LDB) * (N), (C) + ((size_t)LDC) * (N)) +#endif +#define LIBXSMM_MMCALL(FN, A, B, C, M, N, K) LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, M, K, M) + +/** Calculate problem size from M, N, and K using the correct integer type in order to cover the general case. */ +#define LIBXSMM_MNK_SIZE(M, N, K) (((size_t)(M)) * ((size_t)(N)) * ((size_t)(K))) +/** Calculate total number of matrix-elements; matrices A, B, C are given per M, N, K, and emphasize (S) the C-size. */ +#define LIBXSMM_SIZE(M, N, K, S) \ + (((size_t)(M) * (size_t)(K)) + ((size_t)(K) * (size_t)(N)) + \ + (((size_t)(S) * (size_t)(M) * (size_t)(N)))) +/** Condition based on arithmetic intensity (AI) */ +#define LIBXSMM_SMM_AI(M, N, K, S, TYPESIZE) \ + ((LIBXSMM_MNK_SIZE(M, N, K) * 2) <= ((size_t)(TYPESIZE) * 4/*AI*/ * LIBXSMM_SIZE(M, N, K, S))) +/** Determine whether an SMM is suitable, i.e., small enough. */ +#if !defined(LIBXSMM_THRESHOLD_AI) /* traditional MNK-threshold */ +# define LIBXSMM_SMM(M, N, K, S, TYPESIZE) (LIBXSMM_MNK_SIZE(M, N, K) <= (LIBXSMM_MAX_MNK)) +#else /* threshold based on arithmetic intensity */ +# define LIBXSMM_SMM LIBXSMM_SMM_AI +#endif + +/** Fall-back code paths: LIBXSMM_XGEMM_FALLBACK0, and LIBXSMM_XGEMM_FALLBACK1 (macro template). */ +#if !defined(LIBXSMM_XGEMM_FALLBACK0) +# define LIBXSMM_XGEMM_FALLBACK0(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) +#endif +#if !defined(LIBXSMM_XGEMM_FALLBACK1) +# define LIBXSMM_XGEMM_FALLBACK1(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) +#endif + +/** + * Execute a specialized function, or use a fallback code path depending on threshold (macro template). + * LIBXSMM_XGEMM_FALLBACK0 or specialized function: below LIBXSMM_MAX_MNK + * LIBXSMM_XGEMM_FALLBACK1: above LIBXSMM_MAX_MNK + */ +#define LIBXSMM_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \ + const int libxsmm_xgemm_flags_ = LIBXSMM_GEMM_PFLAGS(TRANSA, TRANSB, LIBXSMM_FLAGS) | LIBXSMM_GEMM_XFLAGS(ITYPE, OTYPE); \ + const libxsmm_blasint *const libxsmm_xgemm_k_ = (NULL != (K) ? (K) : (M)); \ + const libxsmm_blasint *const libxsmm_xgemm_n_ = (NULL != (N) ? (N) : libxsmm_xgemm_k_); \ + const libxsmm_blasint libxsmm_xgemm_lda_ = LIBXSMM_MAX(NULL != ((void*)(LDA)) ? *(LDA) : \ + *(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? (M) : libxsmm_xgemm_k_), 1); \ + const libxsmm_blasint libxsmm_xgemm_ldb_ = LIBXSMM_MAX(NULL != ((void*)(LDB)) ? *(LDB) : \ + *(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? libxsmm_xgemm_k_ : libxsmm_xgemm_n_), 1); \ + const libxsmm_blasint libxsmm_xgemm_ldc_ = LIBXSMM_MAX(NULL != (LDC) ? *(LDC) : *(M), 1); \ + if (LIBXSMM_SMM(*(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, 2/*RFO*/, sizeof(OTYPE))) { \ + const LIBXSMM_MMFUNCTION_TYPE2(ITYPE, OTYPE) libxsmm_mmfunction_ = LIBXSMM_MMDISPATCH_SYMBOL2(ITYPE, OTYPE)( \ + *(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, &libxsmm_xgemm_lda_, &libxsmm_xgemm_ldb_, &libxsmm_xgemm_ldc_, \ + (const OTYPE*)(ALPHA), (const OTYPE*)(BETA), &libxsmm_xgemm_flags_, NULL); \ + if (NULL != libxsmm_mmfunction_) { \ + LIBXSMM_MMCALL_LDX(libxsmm_mmfunction_, (const ITYPE*)(A), (const ITYPE*)(B), (OTYPE*)(C), \ + *(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, libxsmm_xgemm_lda_, libxsmm_xgemm_ldb_, libxsmm_xgemm_ldc_); \ + } \ + else { \ + const char libxsmm_xgemm_transa_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? 'n' : 't'); \ + const char libxsmm_xgemm_transb_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? 'n' : 't'); \ + const OTYPE libxsmm_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ + const OTYPE libxsmm_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ + LIBXSMM_XGEMM_FALLBACK0(ITYPE, OTYPE, &libxsmm_xgemm_transa_, &libxsmm_xgemm_transb_, \ + M, libxsmm_xgemm_n_, libxsmm_xgemm_k_, \ + &libxsmm_xgemm_alpha_, A, &libxsmm_xgemm_lda_, \ + B, &libxsmm_xgemm_ldb_, \ + &libxsmm_xgemm_beta_, C, &libxsmm_xgemm_ldc_); \ + } \ + } \ + else { \ + const char libxsmm_xgemm_transa_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? 'n' : 't'); \ + const char libxsmm_xgemm_transb_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? 'n' : 't'); \ + const OTYPE libxsmm_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ + const OTYPE libxsmm_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ + LIBXSMM_XGEMM_FALLBACK1(ITYPE, OTYPE, &libxsmm_xgemm_transa_, &libxsmm_xgemm_transb_, \ + M, libxsmm_xgemm_n_, libxsmm_xgemm_k_, \ + &libxsmm_xgemm_alpha_, A, &libxsmm_xgemm_lda_, \ + B, &libxsmm_xgemm_ldb_, \ + &libxsmm_xgemm_beta_, C, &libxsmm_xgemm_ldc_); \ + } \ +} + +/** Helper macro to setup a matrix with some initial values. */ +#define LIBXSMM_MATINIT_AUX(OMP, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) { \ + /*const*/ double libxsmm_matinit_seed_ = (double)(SEED); /* avoid constant conditional */ \ + const double libxsmm_matinit_scale_ = (SCALE) * libxsmm_matinit_seed_ + (SCALE); \ + const libxsmm_blasint libxsmm_matinit_nrows_ = (libxsmm_blasint)NROWS; \ + const libxsmm_blasint libxsmm_matinit_ld_ = (libxsmm_blasint)LD; \ + libxsmm_blasint libxsmm_matinit_i_ = 0, libxsmm_matinit_j_ = 0; \ + LIBXSMM_OMP_VAR(libxsmm_matinit_i_); LIBXSMM_OMP_VAR(libxsmm_matinit_j_); \ + if (0 != libxsmm_matinit_seed_) { \ + OMP(parallel for private(libxsmm_matinit_i_, libxsmm_matinit_j_)) \ + for (libxsmm_matinit_i_ = 0; libxsmm_matinit_i_ < ((libxsmm_blasint)NCOLS); ++libxsmm_matinit_i_) { \ + for (libxsmm_matinit_j_ = 0; libxsmm_matinit_j_ < libxsmm_matinit_nrows_; ++libxsmm_matinit_j_) { \ + const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \ + (DST)[libxsmm_matinit_k_] = (TYPE)(libxsmm_matinit_scale_ * (1.0 + \ + libxsmm_matinit_i_ * libxsmm_matinit_nrows_ + libxsmm_matinit_j_)); \ + } \ + for (; libxsmm_matinit_j_ < libxsmm_matinit_ld_; ++libxsmm_matinit_j_) { \ + const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \ + (DST)[libxsmm_matinit_k_] = (TYPE)(SEED); \ + } \ + } \ + } \ + else { /* shuffle based initialization */ \ + const unsigned int libxsmm_matinit_maxval_ = ((unsigned int)NCOLS) * ((unsigned int)libxsmm_matinit_ld_); \ + const TYPE libxsmm_matinit_maxval2_ = (TYPE)(libxsmm_matinit_maxval_ / 2), libxsmm_matinit_inv_ = (TYPE)((SCALE) / libxsmm_matinit_maxval2_); \ + const size_t libxsmm_matinit_shuffle_ = libxsmm_shuffle(libxsmm_matinit_maxval_); \ + OMP(parallel for private(libxsmm_matinit_i_, libxsmm_matinit_j_)) \ + for (libxsmm_matinit_i_ = 0; libxsmm_matinit_i_ < ((libxsmm_blasint)NCOLS); ++libxsmm_matinit_i_) { \ + for (libxsmm_matinit_j_ = 0; libxsmm_matinit_j_ < libxsmm_matinit_ld_; ++libxsmm_matinit_j_) { \ + const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \ + (DST)[libxsmm_matinit_k_] = libxsmm_matinit_inv_ * /* normalize values to an interval of [-1, +1] */ \ + ((TYPE)(libxsmm_matinit_shuffle_ * libxsmm_matinit_k_ % libxsmm_matinit_maxval_) - libxsmm_matinit_maxval2_); \ + } \ + } \ + } \ +} + +#define LIBXSMM_MATINIT(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \ + LIBXSMM_MATINIT_AUX(LIBXSMM_ELIDE, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) +#define LIBXSMM_MATINIT_SEQ(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \ + LIBXSMM_MATINIT(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) +#define LIBXSMM_MATINIT_OMP(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \ + LIBXSMM_MATINIT_AUX(LIBXSMM_PRAGMA_OMP, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) + +/** Call libxsmm_gemm_print using LIBXSMM's GEMM-flags. */ +#define LIBXSMM_GEMM_PRINT(OSTREAM, PRECISION, FLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) \ + LIBXSMM_GEMM_PRINT2(OSTREAM, PRECISION, PRECISION, FLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) +#define LIBXSMM_GEMM_PRINT2(OSTREAM, IPREC, OPREC, FLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) \ + libxsmm_gemm_dprint2(OSTREAM, (libxsmm_gemm_precision)(IPREC), (libxsmm_gemm_precision)(OPREC), \ + /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \ + (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & (FLAGS)) ? 'n' : 't'), \ + (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & (FLAGS)) ? 'n' : 't'), \ + M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) + +/** + * Utility function, which either prints information about the GEMM call + * or dumps (FILE/ostream=0) all input and output data into MHD files. + * The Meta Image Format (MHD) is suitable for visual inspection using, + * e.g., ITK-SNAP or ParaView. + */ +LIBXSMM_API void libxsmm_gemm_print(void* ostream, + libxsmm_gemm_precision precision, const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc); +LIBXSMM_API void libxsmm_gemm_print2(void* ostream, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc); +LIBXSMM_API void libxsmm_gemm_dprint(void* ostream, + libxsmm_gemm_precision precision, char transa, char transb, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + double dalpha, const void* a, libxsmm_blasint lda, + const void* b, libxsmm_blasint ldb, + double dbeta, void* c, libxsmm_blasint ldc); +LIBXSMM_API void libxsmm_gemm_dprint2(void* ostream, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, char transa, char transb, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + double dalpha, const void* a, libxsmm_blasint lda, + const void* b, libxsmm_blasint ldb, + double dbeta, void* c, libxsmm_blasint ldc); +LIBXSMM_API void libxsmm_gemm_xprint(void* ostream, + libxsmm_xmmfunction kernel, const void* a, const void* b, void* c); + +/** GEMM_BATCH: fallback prototype functions served by any compliant LAPACK/BLAS. */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemm_batch_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch)); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemm_batch_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch)); +/** GEMM: fallback prototype functions served by any compliant LAPACK/BLAS. */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemm_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm)); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemm_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm)); +/** GEMV: fallback prototype functions served by any compliant LAPACK/BLAS. */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemv_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemv)); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemv_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemv)); +/** Helper function to consume arguments when called. */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sink_function)(LIBXSMM_VARIADIC); + +/** The original BLAS functions. */ +LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch_function); +LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch_function); +LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemm_function libxsmm_original_dgemm_function); +LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemm_function libxsmm_original_sgemm_function); +LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemv_function libxsmm_original_dgemv_function); +LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemv_function libxsmm_original_sgemv_function); +LIBXSMM_API libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void); +LIBXSMM_API libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void); +LIBXSMM_API libxsmm_dgemm_function libxsmm_original_dgemm(void); +LIBXSMM_API libxsmm_sgemm_function libxsmm_original_sgemm(void); +LIBXSMM_API libxsmm_dgemv_function libxsmm_original_dgemv(void); +LIBXSMM_API libxsmm_sgemv_function libxsmm_original_sgemv(void); +LIBXSMM_API libxsmm_sink_function libxsmm_blas_error(const char* symbol); +LIBXSMM_API void libxsmm_sink(LIBXSMM_VARIADIC); + +/** + * General dense matrix multiplication, which re-exposes LAPACK/BLAS + * but allows to rely on LIBXSMM's defaults (libxsmm_config.h) + * when supplying NULL-arguments in certain places. + */ +LIBXSMM_API void libxsmm_blas_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc); + +#define libxsmm_blas_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + libxsmm_blas_xgemm(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, \ + TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) +#define libxsmm_blas_sgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + libxsmm_blas_xgemm(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, \ + TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) + +#define libxsmm_dgemm_omp(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + libxsmm_xgemm_omp(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, \ + TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) +#define libxsmm_sgemm_omp(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ + libxsmm_xgemm_omp(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, \ + TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) + +/** Translates GEMM prefetch request into prefetch-enumeration (incl. FE's auto-prefetch). */ +LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_xprefetch(const int* prefetch); +LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_prefetch(int prefetch); + +/** Determines the given value in double-precision based on the given type. */ +LIBXSMM_API int libxsmm_dvalue(libxsmm_datatype datatype, const void* value, double* dvalue); + +#endif /*LIBXSMM_FRONTEND_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_fsspmdm.h b/third_party/libxsmm/include/libxsmm_fsspmdm.h new file mode 100644 index 0000000000000000000000000000000000000000..46f3275c90fe8e219473b23df970cd71143ea955 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_fsspmdm.h @@ -0,0 +1,40 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_FSSPMDM_H +#define LIBXSMM_FSSPMDM_H + +#include "libxsmm_typedefs.h" + + +/** Opaque types for fsspmdm */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dfsspmdm libxsmm_dfsspmdm; +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_sfsspmdm libxsmm_sfsspmdm; + +LIBXSMM_API libxsmm_dfsspmdm* libxsmm_dfsspmdm_create( libxsmm_blasint M, libxsmm_blasint N, libxsmm_blasint K, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + const double alpha, const double beta, libxsmm_blasint c_is_nt, + const double* a_dense ); + +LIBXSMM_API void libxsmm_dfsspmdm_execute( const libxsmm_dfsspmdm* handle, const double* B, double* C ); + +LIBXSMM_API void libxsmm_dfsspmdm_destroy( libxsmm_dfsspmdm* handle ); + +LIBXSMM_API libxsmm_sfsspmdm* libxsmm_sfsspmdm_create( libxsmm_blasint M, libxsmm_blasint N, libxsmm_blasint K, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + const float alpha, const float beta, libxsmm_blasint c_is_nt, + const float* a_dense ); + +LIBXSMM_API void libxsmm_sfsspmdm_execute( const libxsmm_sfsspmdm* handle, const float* B, float* C ); + +LIBXSMM_API void libxsmm_sfsspmdm_destroy( libxsmm_sfsspmdm* handle ); + +#endif /*LIBXSMM_FSSPMDM_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_generator.h b/third_party/libxsmm/include/libxsmm_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..b08d6abd38ac994e74229f7139255a7a87c772a4 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_generator.h @@ -0,0 +1,219 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_GENERATOR_H +#define LIBXSMM_GENERATOR_H + +#include "libxsmm_typedefs.h" + +#define LIBXSMM_GEMM_NO_BYPASS(FLAGS, ALPHA, BETA) ( \ + 0 == ((FLAGS) & (LIBXSMM_GEMM_FLAG_TRANS_A)) && \ + (LIBXSMM_FEQ(1, ALPHA) /*|| LIBXSMM_FEQ(-1, ALPHA)*/) && \ + (LIBXSMM_FEQ(1, BETA) || LIBXSMM_FEQ(0, BETA))) + + +/** Initialize GEMM descriptor as used by low-level routines (type-specific). */ +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_dgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + double alpha, double beta, int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_sgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + float alpha, float beta, int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_wigemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + int alpha, int beta, int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bigemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + int alpha, int beta, int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bbgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + int alpha, int beta, int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bsgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + float alpha, float beta, int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + float alpha, float beta, int flags, int prefetch); + +/** Initialize GEMM descriptor (generic: double-precision alpha/beta). */ +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_dinit(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision precision, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, double alpha, double beta, + int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_dinit2(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + double alpha, double beta, int flags, int prefetch); + +/** Initialize GEMM descriptor as used by low-level routines (generic). */ +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision precision, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, const void* alpha, const void* beta, + int flags, int prefetch); +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_init2(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, const void* alpha, const void* beta, + int flags, int prefetch); +/** Similar to libxsmm_gemm_descriptor_init2 with optional type-converted alpha/beta (dalpha/dbeta). */ +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_init3(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, const void* alpha, const void* beta, + int flags, int prefetch, double* dalpha, double* dbeta); + +/** Initialize transpose descriptor as used by low-level routines. */ +LIBXSMM_API libxsmm_meltw_descriptor* libxsmm_meltw_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_datatype in_type, libxsmm_datatype out_type, + libxsmm_blasint m, libxsmm_blasint n, + libxsmm_blasint ldo, libxsmm_blasint ldi, + unsigned short flags, unsigned char param, unsigned char operation); +LIBXSMM_API libxsmm_meltw_descriptor* libxsmm_meltw_descriptor_init2(libxsmm_descriptor_blob* blob, + libxsmm_datatype in_type, libxsmm_datatype in2_type, libxsmm_datatype out_type, libxsmm_datatype out2_type, + libxsmm_blasint m, libxsmm_blasint n, + libxsmm_blasint ldo, libxsmm_blasint ldi, libxsmm_blasint ldi2, libxsmm_blasint ldi3, + unsigned short flags, unsigned char param, unsigned char operation); + +/** Initialize matrix equation as used by low-level routines */ +LIBXSMM_API libxsmm_meqn_descriptor* libxsmm_meqn_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_datatype type, libxsmm_blasint m, libxsmm_blasint n, + libxsmm_blasint ldo, unsigned int eqn_idx); + +/** Structure referring to the generated code with some attached information. */ +LIBXSMM_EXTERN_C typedef struct libxsmm_generated_code { + void* generated_code; /** pointer to memory which can contain strings or binary code */ + unsigned int buffer_size; /** total size if the buffer generated_code */ + unsigned int code_size; /** size of bytes used in generated_code */ + unsigned int code_type; /** + * 0: generated code contains inline assembly in a C function + * which can be dumped into a *.c/cc/cpp file + * 1: generated code contains assembly which can be + * dumped into an *.s file + * >1: generated code contains a function in binary code which can be + * called, when the code is copied into executable memory + */ + unsigned int last_error; /** + * 0: no error occurred + * >0: error code + */ + unsigned int arch; /* target arch for the current code generation task */ + unsigned int sf_size; /* offset of RSP to the beginning of the stack frame + * we track this value to have RBP availbale for general compute + */ +} libxsmm_generated_code; + +/** function to translate LIBXSMM Generator error codes to error messages */ +LIBXSMM_API +const char* libxsmm_strerror(unsigned int i_error_code); + +/* @TODO change int based architecture value */ +LIBXSMM_API +void libxsmm_generator_gemm_inlineasm(const char* i_file_out, + const char* i_routine_name, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const char* i_arch ); + +/* @TODO change int based architecture value */ +LIBXSMM_API +void libxsmm_generator_gemm_directasm(const char* i_file_out, + const char* i_routine_name, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const char* i_arch ); + +LIBXSMM_API +void libxsmm_generator_gemm_kernel(libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc ); + +/* @TODO change int based architecture value */ +LIBXSMM_API +void libxsmm_generator_spgemm(const char* i_file_out, + const char* i_routine_name, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const char* i_arch, + const char* i_file_in, + const int i_is_csr); + +/* @TODO change int based architecture value */ +LIBXSMM_API +void libxsmm_generator_spgemm_csc_kernel(libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const char* i_arch, + const unsigned int* i_row_idx, + const unsigned int* i_column_idx, + const double* i_values); + +/* @TODO change int based architecture value */ +LIBXSMM_API +void libxsmm_generator_spgemm_csr_kernel(libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const char* i_arch, + const unsigned int* i_row_idx, + const unsigned int* i_column_idx, + const double* i_values); + +/* @TODO change int based architecture value */ +LIBXSMM_API +void libxsmm_generator_spgemm_csr_reg_kernel(libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const char* i_arch, + const unsigned int* i_row_idx, + const unsigned int* i_column_idx, + const double* i_values); + +LIBXSMM_API +void libxsmm_generator_packed_spgemm_csr_kernel( libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const unsigned int* i_row_idx, + const unsigned int* i_column_idx, + const void* i_values, + const unsigned int i_packed_width ); + +LIBXSMM_API +void libxsmm_generator_packed_spgemm_csc_kernel( libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const unsigned int* i_row_idx, + const unsigned int* i_column_idx, + const void* i_values, + const unsigned int i_packed_width ); + +LIBXSMM_API +void libxsmm_generator_packed_gemm_ac_rm( libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const unsigned int i_packed_width ); + +LIBXSMM_API +void libxsmm_generator_packed_gemm_bc_rm( libxsmm_generated_code* io_generated_code, + const libxsmm_gemm_descriptor* i_xgemm_desc, + const unsigned int i_packed_width ); + +LIBXSMM_API +void libxsmm_generator_mateltwise_kernel( libxsmm_generated_code* io_generated_code, + const libxsmm_meltw_descriptor* i_mateltw_desc ); + +LIBXSMM_API +void libxsmm_generator_matequation_kernel( libxsmm_generated_code* io_generated_code, + const libxsmm_meqn_descriptor* i_mateqn_desc ); + +/** Initialization counter that can be used to check whether the library is initialized (!=0) or not (==0). */ +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_ninit); +/** Target architecture (libxsmm_get_target_archid, libxsmm_set_target_archid). */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_target_archid); +/** Verbosity level (0: quiet, 1: errors, 2: warnings, 3: info, neg.: all/dump). */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_verbosity); +/** Security-enhanced environment. */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_se); + +#endif /*LIBXSMM_GENERATOR_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_intrinsics_x86.h b/third_party/libxsmm/include/libxsmm_intrinsics_x86.h new file mode 100644 index 0000000000000000000000000000000000000000..59ec8676ae63b5531a8990d7e0af55acb38f2f9d --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_intrinsics_x86.h @@ -0,0 +1,1022 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_INTRINSICS_X86_H +#define LIBXSMM_INTRINSICS_X86_H + +#include "libxsmm_cpuid.h" + +/** Macro evaluates to LIBXSMM_ATTRIBUTE_TARGET_xxx (see below). */ +#define LIBXSMM_ATTRIBUTE_TARGET(TARGET) LIBXSMM_CONCATENATE(LIBXSMM_ATTRIBUTE_TARGET_, TARGET) + +#if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_PLATFORM_X86) +# define LIBXSMM_INTRINSICS_NONE +#endif +#if /*no intrinsics: tested with 17.x and 18.x*/(defined(__PGI) && \ + LIBXSMM_VERSION2(19, 0) > LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__)) \ + || /*legacy*/(defined(_CRAYC) && !defined(__GNUC__)) +# if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC) +# define LIBXSMM_INTRINSICS_NONE +# endif +#elif !defined(LIBXSMM_INTRINSICS_STATIC) && !defined(LIBXSMM_INTRINSICS_NONE) && ( \ + (defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && \ + LIBXSMM_VERSION2(4, 4) > LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) /* GCC 4.4 (target-attribute) */ \ + || (defined(__clang__) && LIBXSMM_VERSION2(3, 7) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \ + || (defined(__APPLE__) && defined(__MACH__) && !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && \ + LIBXSMM_VERSION2(9, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) +# define LIBXSMM_INTRINSICS_STATIC +#endif + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif + +/** https://github.com/intel/Immintrin-debug */ +#if !defined(LIBXSMM_INTRINSICS_DEBUG) && 0 +# define LIBXSMM_INTRINSICS_DEBUG +/* workarounds removed after LIBXSMM 1.16.1-1.16.1-1268 */ +# include "immintrin_dbg.h" +#endif +#if defined(__MIC__) && !defined(LIBXSMM_INTRINSICS_NONE) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC +# endif +# define LIBXSMM_INTRINSICS(TARGET) +# define LIBXSMM_INTRINSICS_INCLUDE +#elif !defined(LIBXSMM_INTRINSICS_NONE) /*!defined(__MIC__)*/ +# if defined(__AVX512F__) && defined(__AVX512CD__) \ + && defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) \ + && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \ + && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) /* TODO: check GCC, Clang, etc. */ \ + || (LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \ + && (!defined(__clang__) || (LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \ + && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(99, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__AVX512F__) && defined(__AVX512CD__) \ + && defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) \ + && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \ + && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \ + || (LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \ + && (!defined(__clang__) || (LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \ + && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__AVX512F__) && defined(__AVX512CD__) \ + && defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) \ + && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \ + && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \ + || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \ + && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \ + && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__AVX512F__) && defined(__AVX512CD__) \ + && defined(__AVX512PF__) && defined(__AVX512ER__) \ + && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \ + && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \ + || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \ + && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \ + && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__AVX512F__) && defined(__AVX512CD__) \ + && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \ + && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \ + || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \ + && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \ + && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512 +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2 +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE42 +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(__SSE3__) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE3 +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(LIBXSMM_PLATFORM_X86) +# if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_GENERIC +# endif +# if defined(__GNUC__) +# define LIBXSMM_INTRINSICS_INCLUDE +# endif +# endif +# if defined(LIBXSMM_STATIC_TARGET_ARCH) && !defined(LIBXSMM_INTRINSICS_STATIC) +# if defined(__INTEL_COMPILER) +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) + /* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */ +# if 1904 <= (LIBXSMM_INTEL_COMPILER) && !defined(_WIN32) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX +# elif 1801 <= (LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX +# elif 1500 <= (LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE +# elif 1400 <= (LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC +# else +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2 +# endif +# endif +# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/ +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(_CRAYC) && defined(__GNUC__) + /* TODO: version check, e.g., LIBXSMM_VERSION2(11, 5) <= LIBXSMM_VERSION2(_RELEASE, _RELEASE_MINOR) */ +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX +# endif +# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/ +# define LIBXSMM_INTRINSICS_INCLUDE +# elif defined(_MSC_VER) && !defined(__clang__) + /* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */ +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2 +# endif +# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/ +# define LIBXSMM_INTRINSICS_INCLUDE +# elif (!defined(__GNUC__) || LIBXSMM_VERSION2(4, 9) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \ + && (!defined(__clang__) || LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \ + && (!defined(__APPLE__) || !defined(__MACH__)) && !defined(__PGI) && !defined(_MSC_VER) +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */ +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2 +# elif (defined(__clang__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX +# elif (defined(__GNUC__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \ + || (defined(__clang__) && LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__)) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX +# elif (defined(__GNUC__) && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \ + || (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX +# elif (defined(__GNUC__) && LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \ + || (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE +# else +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2 +# endif +# endif +# define LIBXSMM_INTRINSICS_INCLUDE +# else /* GCC/legacy incl. Clang */ +# if defined(__clang__) && !(defined(__APPLE__) && defined(__MACH__)) && !defined(_WIN32) +# if (LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) /* TODO */ + /* no limitations */ +# elif (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/) +# define LIBXSMM_INTRINSICS_STATIC +# endif +# elif !defined(LIBXSMM_INTRINSICS_STATIC) +# define LIBXSMM_INTRINSICS_STATIC +# endif +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */ +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2 +# elif LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX +# elif LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX +# elif LIBXSMM_VERSION2( 6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX +# else +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE +# endif +# endif +# else /* fallback */ +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH +# endif +# if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/) +# define LIBXSMM_INTRINSICS_STATIC +# endif +# endif +# if !defined(LIBXSMM_INTRINSICS_INCLUDE) && (!defined(__PGI) || LIBXSMM_VERSION2(19, 0) <= LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__)) +# define LIBXSMM_INTRINSICS_INCLUDE +# endif +# endif /* GCC/legacy incl. Clang */ +# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# error "LIBXSMM_MAX_STATIC_TARGET_ARCH not defined!" +# endif +# if defined(LIBXSMM_TARGET_ARCH) && (LIBXSMM_TARGET_ARCH < LIBXSMM_MAX_STATIC_TARGET_ARCH) +# undef LIBXSMM_MAX_STATIC_TARGET_ARCH +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH +# endif +# if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_DEBUG) +# include +# endif /*defined(LIBXSMM_INTRINSICS_INCLUDE)*/ +# if !defined(LIBXSMM_INTRINSICS) +# if (LIBXSMM_MAX_STATIC_TARGET_ARCH > LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_INTRINSICS(TARGET) LIBXSMM_ATTRIBUTE(LIBXSMM_ATTRIBUTE_TARGET(TARGET)) + /* LIBXSMM_ATTRIBUTE_TARGET_xxx is required to literally match the CPUID (libxsmm_cpuid.h)! */ +# define LIBXSMM_ATTRIBUTE_TARGET_1002 target("sse2") /* LIBXSMM_X86_GENERIC (64-bit ABI) */ +# if (LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1003 target("sse3") +# else +# define LIBXSMM_ATTRIBUTE_TARGET_1003 LIBXSMM_ATTRIBUTE_TARGET_1002 +# endif +# if (LIBXSMM_X86_SSE42 <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1004 target("sse4.1,sse4.2") +# else +# define LIBXSMM_ATTRIBUTE_TARGET_1004 LIBXSMM_ATTRIBUTE_TARGET_1003 +# endif +# if (LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1005 target("avx") +# else +# define LIBXSMM_ATTRIBUTE_TARGET_1005 LIBXSMM_ATTRIBUTE_TARGET_1004 +# endif +# if (LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1006 target("avx2,fma") +# else +# define LIBXSMM_ATTRIBUTE_TARGET_1006 LIBXSMM_ATTRIBUTE_TARGET_1005 +# endif +# if (LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1007 target("avx2,fma,avx512f,avx512cd") +# else +# define LIBXSMM_ATTRIBUTE_TARGET_1007 LIBXSMM_ATTRIBUTE_TARGET_1006 +# endif +# if (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1010 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er") +# else /* LIBXSMM_X86_AVX512 */ +# define LIBXSMM_ATTRIBUTE_TARGET_1010 LIBXSMM_ATTRIBUTE_TARGET_1007 +# endif +# if (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1011 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er,avx5124vnniw,avx5124fmaps") +# else /* LIBXSMM_X86_AVX512_MIC */ +# define LIBXSMM_ATTRIBUTE_TARGET_1011 LIBXSMM_ATTRIBUTE_TARGET_1010 +# endif +# if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1020 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl") +# else /* LIBXSMM_X86_AVX512 */ +# define LIBXSMM_ATTRIBUTE_TARGET_1020 LIBXSMM_ATTRIBUTE_TARGET_1007 +# endif +# if (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1021 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni") +# else /* LIBXSMM_X86_AVX512_CORE */ +# define LIBXSMM_ATTRIBUTE_TARGET_1021 LIBXSMM_ATTRIBUTE_TARGET_1020 +# endif +# if (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_ATTRIBUTE_TARGET_1022 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni,avx512bf16") +# else /* LIBXSMM_X86_AVX512_CORE */ +# define LIBXSMM_ATTRIBUTE_TARGET_1022 LIBXSMM_ATTRIBUTE_TARGET_1021 +# endif +# else +# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/ +# endif +# elif !defined(LIBXSMM_INTRINSICS_TARGET) +# define LIBXSMM_INTRINSICS_TARGET +# endif /*!defined(LIBXSMM_INTRINSICS)*/ +# endif /*defined(LIBXSMM_STATIC_TARGET_ARCH)*/ +#endif /*!defined(LIBXSMM_INTRINSICS_NONE)*/ + +#if !defined(LIBXSMM_STATIC_TARGET_ARCH) +# if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC) +# define LIBXSMM_INTRINSICS_NONE +# endif +# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC +#endif + +#if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH +#elif (LIBXSMM_MAX_STATIC_TARGET_ARCH < LIBXSMM_STATIC_TARGET_ARCH) +# undef LIBXSMM_MAX_STATIC_TARGET_ARCH +# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH +#endif + +#if !defined(LIBXSMM_INTRINSICS) +# define LIBXSMM_INTRINSICS(TARGET) +#endif + +/** Include basic x86 intrinsics such as __rdtsc. */ +#if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_DEBUG) +# if defined(_WIN32) +# include +# elif defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) || defined(__clang__) || defined(__PGI) +# include +# elif defined(__GNUC__) && (LIBXSMM_VERSION2(4, 4) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) +# include +# endif +# include +# if defined(__SSE3__) +# include +# endif +#endif + +#if !defined(LIBXSMM_INTRINSICS_NONE) +# if defined(_WIN32) +# include +# else +# include +# endif +#endif + +/** + * Intrinsic-specific fix-ups + */ +# define LIBXSMM_INTRINSICS_LOADU_SI128(A) _mm_loadu_si128(A) +#if !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && ( \ + (LIBXSMM_VERSION2(3, 9) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \ + || (LIBXSMM_VERSION2(7, 3) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && \ + defined(__APPLE__) && defined(__MACH__))) +/* prototypes with incorrect signature: _mm512_load_ps takes DP*, _mm512_load_pd takes SP* (checked with v3.8.1) */ +# define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const double*)(A)) +# define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const float*)(A)) +/* Clang misses _mm512_stream_p? (checked with v3.8.1). */ +# define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_store_si512(A, B) +# define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_storeu_ps(A, B) +# define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_store_pd(A, B) +#else +# define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const float*)(A)) +# define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const double*)(A)) +# define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_stream_si512((__m512i*)(A), (B)) +# define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_stream_ps(A, B) +# define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_stream_pd(A, B) +#endif +#if !defined(LIBXSMM_INTEL_COMPILER) || (defined(__clang__) && ( \ + (LIBXSMM_VERSION2(8, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))) \ + || (defined(__APPLE__) && defined(__MACH__)) || defined(__GNUC__) +# define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_si256((__m256i*)(A), B) +#else +# define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_epi32(A, B) +#endif +#if defined(LIBXSMM_INTEL_COMPILER) +# if 1600 <= (LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \ + E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \ + _mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \ + E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) +# else +# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \ + E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \ + _mm512_castps_si512(_mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \ + E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0)) +# endif +#else +# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \ + E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \ + _mm512_set_epi32(((E31) << 16) | (E30), ((E29) << 16) | (E28), ((E27) << 16) | (E26), ((E25) << 16) | (E24), \ + ((E23) << 16) | (E22), ((E21) << 16) | (E20), ((E19) << 16) | (E18), ((E17) << 16) | (E16), \ + ((E15) << 16) | (E14), ((E13) << 16) | (E12), ((E11) << 16) | (E10), ((E9) << 16) | (E8), \ + ((E7) << 16) | (E6), ((E5) << 16) | (E4), ((E3) << 16) | (E2), ((E1) << 16) | (E0)) +#endif +#if defined(LIBXSMM_INTEL_COMPILER) \ + || (defined(__GNUC__) && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \ + || (defined(__clang__) && (!defined(__APPLE__) || !defined(__MACH__)) \ + && LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_mask_i32gather_epi32(A, B, C, D, E) +# define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm512_extracti64x4_epi64(A, B) +# define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_abs_ps(A) +# define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_undefined_epi32() +# define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_undefined() +# define LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256() _mm256_undefined_si256() +# define LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128() _mm_undefined_si128() +# define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_undefined_pd() +#else +# define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_castps_si512(_mm512_mask_i32gather_ps( \ + _mm512_castsi512_ps(A), B, C, (const float*)(D), E)) +# define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castsi512_pd(A), B)) +# define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_castsi512_ps(_mm512_and_epi32( \ + _mm512_castps_si512(A), _mm512_set1_epi32(0x7FFFFFFF))) +# define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_set1_epi32(0) +# define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_set1_ps(0) +# define LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256() _mm256_set1_epi32(0) +# define LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128() _mm_set1_epi32(0) +# define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_set1_pd(0) +#endif +#if (defined(LIBXSMM_INTEL_COMPILER) && (1800 <= (LIBXSMM_INTEL_COMPILER))) \ + || (!defined(LIBXSMM_INTEL_COMPILER) && defined(__GNUC__) \ + && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \ + || ((!defined(__APPLE__) || !defined(__MACH__)) && defined(__clang__) \ + && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \ + LIBXSMM_CONCATENATE(_store_mask, NBITS)((LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR), SRC) +# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \ + LIBXSMM_CONCATENATE(_load_mask, NBITS)((/*const*/ LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR)) +# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(_cvtu32_mask, NBITS)((unsigned int)(A)) +#elif defined(LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \ + (*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC)) +# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \ + ((LIBXSMM_CONCATENATE(__mmask, NBITS))_mm512_mask2int(*(const __mmask16*)(SRC_PTR))) +# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_, NBITS)(A) +# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_32(A) ((__mmask32)(A)) +# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_16(A) _mm512_int2mask((int)(A)) +# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_8(A) ((__mmask8)(A)) +#else +# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \ + (*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC)) +# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) (*(const LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR)) +# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) ((LIBXSMM_CONCATENATE(__mmask, NBITS))(A)) +#endif +#define LIBXSMM_INTRINSICS_MM512_STORE_MASK64(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 64) +#define LIBXSMM_INTRINSICS_MM512_STORE_MASK32(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 32) +#define LIBXSMM_INTRINSICS_MM512_STORE_MASK16(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 16) +#define LIBXSMM_INTRINSICS_MM512_STORE_MASK8(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 8) +#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK64(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 64) +#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK32(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 32) +#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK16(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 16) +#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK8(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 8) +#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK32(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 32) +#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 16) +#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 8) + +/** + * Pseudo intrinsics for portability + */ +LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD32_SW(unsigned int n) { + unsigned int i, r = 0; if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r; +} +LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD64_SW(unsigned long long n) { + unsigned int i, r = 0; if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r; +} + +/** Binary Logarithm (based on Stackoverflow's NBITSx macro). */ +#define LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N) (0 != ((N) & 0x2/*0b10*/) ? 1 : 0) +#define LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N) (0 != ((N) & 0xC/*0b1100*/) ? (2 | LIBXSMM_INTRINSICS_BITSCANBWD_SW02((N) >> 2)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N)) +#define LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N) (0 != ((N) & 0xF0/*0b11110000*/) ? (4 | LIBXSMM_INTRINSICS_BITSCANBWD_SW04((N) >> 4)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N)) +#define LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N) (0 != ((N) & 0xFF00) ? (8 | LIBXSMM_INTRINSICS_BITSCANBWD_SW08((N) >> 8)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N)) +#define LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N) (0 != ((N) & 0xFFFF0000) ? (16 | LIBXSMM_INTRINSICS_BITSCANBWD_SW16((N) >> 16)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N)) +#define LIBXSMM_INTRINSICS_BITSCANBWD_SW64(N) (0 != ((N) & 0xFFFFFFFF00000000) ? (32 | LIBXSMM_INTRINSICS_BITSCANBWD_SW32((N) >> 32)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N)) +#define LIBXSMM_INTRINSICS_BITSCANBWD32_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW32((unsigned int)(N)) +#define LIBXSMM_INTRINSICS_BITSCANBWD64_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW64((unsigned long long)(N)) + +#if defined(_WIN32) && !defined(LIBXSMM_INTRINSICS_NONE) + LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD32(unsigned int n) { + unsigned long r = 0; _BitScanForward(&r, n); return (0 != n) * r; + } + LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD32(unsigned int n) { + unsigned long r = 0; _BitScanReverse(&r, n); return r; + } +# if defined(_WIN64) + LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD64(unsigned long long n) { + unsigned long r = 0; _BitScanForward64(&r, n); return (0 != n) * r; + } + LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD64(unsigned long long n) { + unsigned long r = 0; _BitScanReverse64(&r, n); return r; + } +# else +# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW +# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW +# endif +#elif defined(__GNUC__) && !defined(LIBXSMM_INTRINSICS_NONE) +# define LIBXSMM_INTRINSICS_BITSCANFWD32(N) ((0 != (N)) * __builtin_ctz(N)) +# define LIBXSMM_INTRINSICS_BITSCANFWD64(N) ((0 != (N)) * __builtin_ctzll(N)) +# define LIBXSMM_INTRINSICS_BITSCANBWD32(N) ((0 != (N)) * (31 - __builtin_clz(N))) +# define LIBXSMM_INTRINSICS_BITSCANBWD64(N) ((0 != (N)) * (63 - __builtin_clzll(N))) +#else /* fallback implementation */ +# define LIBXSMM_INTRINSICS_BITSCANFWD32 LIBXSMM_INTRINSICS_BITSCANFWD32_SW +# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW +# define LIBXSMM_INTRINSICS_BITSCANBWD32 LIBXSMM_INTRINSICS_BITSCANBWD32_SW +# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW +#endif + +/** LIBXSMM_NBITS determines the minimum number of bits needed to represent N. */ +#define LIBXSMM_NBITS(N) (LIBXSMM_INTRINSICS_BITSCANBWD64(N) + LIBXSMM_MIN(1, N)) +#define LIBXSMM_ISQRT2(N) ((unsigned int)((1ULL << (LIBXSMM_NBITS(N) >> 1)) /*+ LIBXSMM_MIN(1, N)*/)) +/** LIBXSMM_ILOG2 definition matches ceil(log2(N)). */ +LIBXSMM_API_INLINE unsigned int LIBXSMM_ILOG2(unsigned long long n) { + unsigned int result = 0; if (1 < n) { + const unsigned int m = LIBXSMM_INTRINSICS_BITSCANBWD64(n); + result = m + ((unsigned int)LIBXSMM_INTRINSICS_BITSCANBWD64(n - 1) == m); + } return result; +} + +/** + * Target attribution + */ +#if !defined(LIBXSMM_INTRINSICS_KNC) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(__MIC__) +# define LIBXSMM_INTRINSICS_KNC +#endif +/** LIBXSMM_INTRINSICS_X86 is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_X86) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_GENERIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_X86 +#endif +/** LIBXSMM_INTRINSICS_SSE3 is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_SSE3) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE3 <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_SSE3 +#endif +/** LIBXSMM_INTRINSICS_SSE42 is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_SSE42) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE42 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_SSE42 +#endif +/** LIBXSMM_INTRINSICS_AVX is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX +#endif +/** LIBXSMM_INTRINSICS_AVX2 is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX2 +#endif +/** LIBXSMM_INTRINSICS_AVX512 is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX512) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX512 +#endif +/** LIBXSMM_INTRINSICS_AVX512_MIC is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX512_MIC) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX512_MIC +#endif +/** LIBXSMM_INTRINSICS_AVX512_KNM is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX512_KNM) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX512_KNM +#endif +/** LIBXSMM_INTRINSICS_AVX512_CORE is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX512_CORE) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX512_CORE +#endif +/** LIBXSMM_INTRINSICS_AVX512_CLX is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX512_CLX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_STATIC_TARGET_ARCH || \ + (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)) +# define LIBXSMM_INTRINSICS_AVX512_CLX +#endif +/** LIBXSMM_INTRINSICS_AVX512_CPX is defined only if the compiler is able to generate this code without special flags. */ +#if !defined(LIBXSMM_INTRINSICS_AVX512_CPX) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(LIBXSMM_X86_AVX512_CPX) && \ + !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +# define LIBXSMM_INTRINSICS_AVX512_CPX +#endif + +/** 2048-bit state for xoshiro128+ RNG (state/symbols needed even if AVX-512 is not used) */ +#define LIBXSMM_INTRINSICS_MM512_RNG_STATE(INDEX) (*(__m512i*)LIBXSMM_CONCATENATE(libxsmm_intrinsics_mm512_rng_state, INDEX)) +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state0[16]); +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state1[16]); +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state2[16]); +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state3[16]); + +/** + * Pseudo intrinsics (AVX-2) + */ +#if defined(LIBXSMM_INTRINSICS_AVX2) /*__AVX2__*/ +# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0 +LIBXSMM_PRAGMA_OPTIMIZE_OFF /* avoid ICE in case of symbols (-g) */ +# endif +/** Generate random number in the interval [0, 1); thread save, state needs to be managed by user. + * this is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c */ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) __m256i LIBXSMM_INTRINSICS_MM256_RNG_XOSHIRO128P_EXTSTATE_EPI32(unsigned int* stateptr) { + __m256i state_0 = _mm256_loadu_si256( (const __m256i*)stateptr ); + __m256i state_1 = _mm256_loadu_si256( (const __m256i*)(stateptr+16) ); + __m256i state_2 = _mm256_loadu_si256( (const __m256i*)(stateptr+32) ); + __m256i state_3 = _mm256_loadu_si256( (const __m256i*)(stateptr+48) ); + const __m256i result = _mm256_add_epi32(state_0, state_3); + const __m256i s = _mm256_slli_epi32(state_1, 9); + __m256i t; + state_2 = _mm256_xor_si256(state_2, state_0); + state_3 = _mm256_xor_si256(state_3, state_1); + state_1 = _mm256_xor_si256(state_1, state_2); + state_0 = _mm256_xor_si256(state_0, state_3); + state_2 = _mm256_xor_si256(state_2, s); + _mm256_storeu_si256( (__m256i*)stateptr , state_0 ); + _mm256_storeu_si256( (__m256i*)(stateptr+16), state_1 ); + _mm256_storeu_si256( (__m256i*)(stateptr+32), state_2 ); + t = _mm256_slli_epi32(state_3, 11); + state_3 = _mm256_or_si256(t, _mm256_srli_epi32(state_3, 32 - 11)); + _mm256_storeu_si256( (__m256i*)(stateptr+48), state_3 ); + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) __m256 LIBXSMM_INTRINSICS_MM256_RNG_EXTSTATE_PS(unsigned int* stateptr) { + const __m256i rng_mantissa = _mm256_srli_epi32( LIBXSMM_INTRINSICS_MM256_RNG_XOSHIRO128P_EXTSTATE_EPI32(stateptr), 9 ); + const __m256 one = _mm256_set1_ps(1.0f); + return _mm256_sub_ps(_mm256_castsi256_ps(_mm256_or_si256(_mm256_set1_epi32(0x3f800000), rng_mantissa)), one); +} +# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0 +LIBXSMM_PRAGMA_OPTIMIZE_ON +# endif +#endif /*__AVX2__*/ + +/** + * Pseudo intrinsics (AVX-512) + */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ +# define LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( A, B ) _mm512_cvtepi32_epi16(_mm512_cvt_roundps_epi32( \ + _mm512_mul_ps(LIBXSMM_INTRINSICS_MM512_LOAD_PS(A), B), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(__m512 a) { + const __m512i vnaninf = _mm512_set1_epi32(0x7f800000), vrneadd = _mm512_set1_epi32(0x00007fff); + const __m512i vfixup = _mm512_set1_epi32(0x00000001), vfixupmask = _mm512_set1_epi32(0x00010000); + const __m512i mm512_roundbf16rne_a_ = _mm512_castps_si512(a); + const __mmask16 mm512_roundbf16rne_mask1_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vnaninf), vnaninf, _MM_CMPINT_NE); + const __mmask16 mm512_roundbf16rne_mask2_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ); + return _mm512_mask_add_epi32(mm512_roundbf16rne_a_, mm512_roundbf16rne_mask1_, mm512_roundbf16rne_a_, _mm512_mask_add_epi32(vrneadd, mm512_roundbf16rne_mask2_, vrneadd, vfixup)); +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m256i LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(__m512 a) { + return _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16)); +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_CVT2_FP32_BF16(__m512 a, __m512 b) { + const __m256i aa = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(b), 16)); + const __m256i bb = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16)); + return _mm512_inserti64x4(_mm512_inserti64x4(_mm512_setzero_si512(), aa, 0), bb, 1); +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(__m256i a) { + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(a),16)); +} + +/** SVML-intrinsics */ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(__m512 x) { + const __m512 c0 = _mm512_set1_ps(2027025.0f); + const __m512 c1 = _mm512_set1_ps(270270.0f); + const __m512 c2 = _mm512_set1_ps(6930.0f); + const __m512 c3 = _mm512_set1_ps(36.0f); + const __m512 c1_d = _mm512_set1_ps(945945.0f); + const __m512 c2_d = _mm512_set1_ps(51975.0f); + const __m512 c3_d = _mm512_set1_ps(630.0f); + const __m512 hi_bound = _mm512_set1_ps(4.97f); + const __m512 lo_bound = _mm512_set1_ps(-4.97f); + const __m512 ones = _mm512_set1_ps(1.0f); + const __m512 neg_ones = _mm512_set1_ps(-1.0f); + + const __m512 x2 = _mm512_mul_ps( x, x ); + const __m512 t1_nom = _mm512_fmadd_ps( c3, x2, c2 ); + const __m512 t2_nom = _mm512_fmadd_ps( t1_nom, x2, c1 ); + const __m512 t3_nom = _mm512_fmadd_ps( t2_nom, x2, c0 ); + const __m512 nom = _mm512_mul_ps( t3_nom, x ); + const __m512 t1_denom = _mm512_add_ps( x2, c3_d ); + const __m512 t2_denom = _mm512_fmadd_ps( t1_denom, x2, c2_d ); + const __m512 t3_denom = _mm512_fmadd_ps( t2_denom, x2, c1_d ); + const __m512 denom = _mm512_fmadd_ps( t3_denom, x2, c0 ); + const __m512 denom_rcp = _mm512_rcp14_ps( denom ); + const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ); + const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ); + __m512 result = _mm512_mul_ps( nom, denom_rcp ); + result = _mm512_mask_blend_ps(mask_hi, result, ones); + result = _mm512_mask_blend_ps(mask_lo, result, neg_ones); + + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_32(__m512 x) { + const __m512 c1 = _mm512_set1_ps((float)(1.0/27.0)); + const __m512 c2 = _mm512_set1_ps((float)(1.0/3)); + const __m512 hi_bound = _mm512_set1_ps(3.2f); + const __m512 lo_bound = _mm512_set1_ps(-3.2f); + const __m512 ones = _mm512_set1_ps(1.0f); + const __m512 neg_ones = _mm512_set1_ps(-1.0f); + + const __m512 x2 = _mm512_mul_ps( x, x ); + const __m512 t1_nom = _mm512_fmadd_ps( x2, c1, ones); + const __m512 nom = _mm512_mul_ps( t1_nom, x ); + const __m512 denom = _mm512_fmadd_ps( x2, c2, ones); + const __m512 denom_rcp = _mm512_rcp14_ps( denom ); + const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ); + const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ); + __m512 result = _mm512_mul_ps(nom, denom_rcp); + result = _mm512_mask_blend_ps(mask_hi, result, ones); + result = _mm512_mask_blend_ps(mask_lo, result, neg_ones); + + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP2(__m512 _x) { + const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2)); + const __m512 half = _mm512_set1_ps(0.5f); + const __m512 c2 = _mm512_set1_ps(0.240226507f); + const __m512 c1 = _mm512_set1_ps(0.452920674f); + const __m512 c0 = _mm512_set1_ps(0.713483036f); + const __m512 ones = _mm512_set1_ps(1.0f); + const __m512 minus_twos = _mm512_set1_ps(-2.0f); + + const __m512 x = _mm512_fmadd_ps(_x, twice_log2_e, half); +#if 1 + const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION)); +#else + const __m512 y = _mm512_reduce_ps(x, 1); +#endif + const __m512 t1 = _mm512_fmadd_ps( y, c2, c1); + const __m512 two_to_y = _mm512_fmadd_ps( y, t1, c0); + const __m512 exp = _mm512_scalef_ps( two_to_y, x ); + const __m512 denom_rcp = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) ); + __m512 result = _mm512_fmadd_ps( denom_rcp, minus_twos, ones); + + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP3(__m512 _x) { + const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2)); + const __m512 half = _mm512_set1_ps(0.5f); + const __m512 c3 = _mm512_set1_ps(0.05550410866f); + const __m512 c2 = _mm512_set1_ps(0.15697034396f); + const __m512 c1 = _mm512_set1_ps(0.49454875509f); + const __m512 c0 = _mm512_set1_ps(0.70654502287f); + const __m512 ones = _mm512_set1_ps(1.0f); + const __m512 minus_twos = _mm512_set1_ps(-2.0f); + + const __m512 x = _mm512_fmadd_ps(_x, twice_log2_e, half); +#if 1 + const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION)); +#else + const __m512 y = _mm512_reduce_ps(x, 1); +#endif + const __m512 t1 = _mm512_fmadd_ps( y, c3, c2); + const __m512 t2 = _mm512_fmadd_ps( y, t1, c1); + const __m512 two_to_y = _mm512_fmadd_ps( y, t2, c0); + const __m512 exp = _mm512_scalef_ps( two_to_y, x ); + const __m512 denom_rcp = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) ); + __m512 result = _mm512_fmadd_ps( denom_rcp, minus_twos, ones); + + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(__m512 x) { + __m512 result, func_p0, func_p1, func_p2; + const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 ); + const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF ); + const __m512i lut_low = _mm512_set1_epi32( 246 ); + const __m512i lut_high = _mm512_set1_epi32( 261 ); + const __m512 tanh_p0_2_reg = _mm512_set_ps( 0.40555000f, 0.11892800f, -0.00972979f, -0.02740300f, -0.0169851f, -0.00776152f, -0.00305889f, + -0.00116259f, -0.00041726f, -8.53233e-6f, 1.0000000f, 0.99999800f, 0.99975400f, 0.99268200f, + 0.93645300f, 0.73833900f); + const __m512 tanh_p1_2_reg = _mm512_set_ps( 0.495602f, 0.88152f, 1.125700000f, 1.17021000f, 1.1289000000f, 1.07929000f, 1.0432300f, 1.023010f, + 1.011620f, 1.00164f, 1.56828e-14f, 4.49924e-7f, 0.0000646924f, 0.00260405f, 0.0311608f, 0.168736f); + const __m512 tanh_p2_2_reg = _mm512_set_ps(-0.108238f, -0.2384280f, -0.354418000f, -0.38240300f, -0.34135700f, -0.274509000f, -0.20524900f, -0.1511960f, + -0.107635f, -0.0466868f, -3.60822e-16f, -2.05971e-8f, -4.24538e-6f, -0.000231709f, -0.00386434f, -0.0277702f); + + const __m512i signs = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask); + const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter); + __m512i indices = _mm512_srli_epi32(abs_arg, 22); + indices = _mm512_max_epi32(indices, lut_low); + indices = _mm512_min_epi32(indices, lut_high); + + func_p0 = _mm512_permutexvar_ps(indices, tanh_p0_2_reg); + func_p1 = _mm512_permutexvar_ps(indices, tanh_p1_2_reg); + func_p2 = _mm512_permutexvar_ps(indices, tanh_p2_2_reg); + + result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p2, func_p1); + result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0); + result = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs)); + + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX3(__m512 x) { + __m512 result, func_p0, func_p1, func_p2, func_p3; + const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 ); + const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF ); + const __m512i lut_low = _mm512_set1_epi32( 246 ); + const __m512i lut_high = _mm512_set1_epi32( 261 ); + + const __m512 tanh_p0_3_reg = _mm512_setr_ps( 0.466283000f, 0.82850600f, 0.97437500f, 0.99882600f, 0.9999860f, 1.0000000f, -1.50006e-08f, -7.98169e-06f, + -4.53753e-05f, -0.00023755f, -0.00125285f, -0.00572314f, -0.0227717f, -0.0629089f, -0.084234300f, 0.071199800f); + const __m512 tanh_p1_3_reg = _mm512_setr_ps( 0.500617f, 0.124369f, 0.0137214f, 0.000464124f, 4.02465e-06f, 0.00000f, 1.00001f, 1.00028f, 1.00112f, 1.00414f, + 1.015570f, 1.050950f, 1.1478500f, 1.310130000f, 1.378950000f, 1.07407f); + const __m512 tanh_p2_3_reg = _mm512_setr_ps(-0.16133200f, -0.0305526f, -0.00245909f, -6.12647e-05f, -3.76127e-07f, 0.000000f, -0.000245872f, -0.00341151f, + -0.00971505f, -0.0256817f, -0.06869110f, -0.162433000f, -0.346828000f, -0.566516f, -0.640214000f, -0.44011900f); + const __m512 tanh_p3_3_reg = _mm512_setr_ps( 0.0177393f, 0.00253432f, 0.000147303f, 2.69963e-06f, 1.16764e-08f, 0.0000000f, -0.330125f, -0.3176210f, + -0.3017760f, -0.27358000f, -0.219375000f, -0.136197000f, -0.01868680f, 0.0808901f, 0.107095f, 0.0631459f); + + const __m512i signs = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask); + const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter); + __m512i indices = _mm512_srli_epi32(abs_arg, 22); + indices = _mm512_max_epi32(indices, lut_low); + indices = _mm512_min_epi32(indices, lut_high); + + func_p0 = _mm512_permutexvar_ps(indices, tanh_p0_3_reg); + func_p1 = _mm512_permutexvar_ps(indices, tanh_p1_3_reg); + func_p2 = _mm512_permutexvar_ps(indices, tanh_p2_3_reg); + func_p3 = _mm512_permutexvar_ps(indices, tanh_p3_3_reg); + + result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p3, func_p2); + result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p1); + result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0); + result = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs)); + + return result; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512DQ__ needed*/ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) __m512 LIBXSMM_INTRINSICS_MM512_GELU_FWD_PS_MINIMAX3(__m512 x) { + const __m512 thres = _mm512_castsi512_ps(_mm512_set1_epi32(0x40879fff)); + const __m512 absmask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff)); + const __m512 scale = _mm512_castsi512_ps(_mm512_set1_epi32(0x406a0ea1)); + const __m512 shifter = _mm512_castsi512_ps(_mm512_set1_epi32(0x4b400000)); + const __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); + const __m512 _c2 = _mm512_castsi512_ps(_mm512_setr_epi32(0xbd877b85u, 0xbd7d9780u, 0xbd4cb70eu, 0xbd08a1e9u, 0xbc808857u, 0xb9476fd2u, 0x3c36f765u, 0x3c924160u, + 0x3ca7b1fcu, 0x3ca5732cu, 0x3c95af63u, 0x3c8079f7u, 0x3c55fa4fu, 0x3c2fa86bu, 0x3c0fbb00u, 0x3bec178cu)); + const __m512 _c1 = _mm512_castsi512_ps(_mm512_setr_epi32(0xb7c7fb58u, 0xbacb9740u, 0xbc3e4b3au, 0xbd0d292au, 0xbd8bc5d0u, 0xbdd9978fu, 0xbe0f92d3u, 0xbe27b66du, + 0xbe328ce7u, 0xbe3125bfu, 0xbe26dc9du, 0xbe17a056u, 0xbe06bdebu, 0xbdecc593u, 0xbdcf57aau, 0xbdb5ea3au)); + const __m512 _c0 = _mm512_castsi512_ps(_mm512_setr_epi32(0x3ecc4231u, 0x3ecc541cu, 0x3ecd6c48u, 0x3ed174c3u, 0x3ed9bd5du, 0x3ee5acd5u, 0x3ef2aeddu, 0x3efd5384u, + 0x3f016724u, 0x3f00f778u, 0x3efb389eu, 0x3ef0464du, 0x3ee3014fu, 0x3ed50a78u, 0x3ec779dbu, 0x3ebae363u)); + __m512 result; + __m512 xr = _mm512_range_round_ps(x, thres, 2, _MM_FROUND_NO_EXC); + __m512 xa = _mm512_and_ps(xr, absmask); + __m512 index = _mm512_fmadd_ps(xa, scale, shifter); + __m512 c2 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c2); + __m512 c1 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c1); + __m512 c0 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c0); + __m512 poly = _mm512_fmadd_ps(c2, xa, c1); + poly = _mm512_fmadd_ps(poly, xa, c0); + result = _mm512_mul_ps(x, _mm512_fmadd_ps(poly, xr, half)); + + return result; +} +#endif /*defined(LIBXSMM_INTRINSICS_AVX512_CORE)*/ + +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512DQ__ needed*/ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) __m512 LIBXSMM_INTRINSICS_MM512_GELU_BWD_PS_MINIMAX3(__m512 x) { + const __m512 thres = _mm512_castsi512_ps(_mm512_set1_epi32(0x408f5fff)); + const __m512 absmask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff)); + const __m512 scale = _mm512_castsi512_ps(_mm512_set1_epi32(0x405d67c9)); + const __m512 shifter = _mm512_castsi512_ps(_mm512_set1_epi32(0x4b400000)); + const __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); + const __m512 _c2 = _mm512_castsi512_ps(_mm512_setr_epi32(0xbe87047bu, 0xbe6eb875u, 0xbe2210c1u, 0xbd81727fu, 0x3cb9625cu, 0x3da2cbe8u, 0x3dd1d4d1u, 0x3dca0bd0u, + 0x3da47dd0u, 0x3d6f1bd3u, 0x3d216381u, 0x3cd2618cu, 0x3c89f6e6u, 0x3c3ca672u, 0x3c08ed08u, 0x3bd26a14u)); + const __m512 _c1 = _mm512_castsi512_ps(_mm512_setr_epi32(0xb930e738u, 0xbc4b28bau, 0xbda4212fu, 0xbe5feb0eu, 0xbec8b0e5u, 0xbf09e61bu, 0xbf1c403fu, 0xbf185954u, + 0xbf03e1eeu, 0xbed08a61u, 0xbe9b4508u, 0xbe61788bu, 0xbe257770u, 0xbdfc542au, 0xbdca014eu, 0xbda8d7e9u)); + const __m512 _c0 = _mm512_castsi512_ps(_mm512_setr_epi32(0x3f4c4245u, 0x3f4c927bu, 0x3f5085f8u, 0x3f5d7bdau, 0x3f73ea12u, 0x3f86142fu, 0x3f8d3df4u, 0x3f8b4b0fu, + 0x3f8022c8u, 0x3f5e5423u, 0x3f39ceb5u, 0x3f199bedu, 0x3f00bee0u, 0x3ede1737u, 0x3ec59b86u, 0x3eb4454cu)); + __m512 result; + __m512 xr = _mm512_range_round_ps(x, thres, 2, _MM_FROUND_NO_EXC); + __m512 xa = _mm512_and_ps(xr, absmask); + __m512 index = _mm512_fmadd_ps(xa, scale, shifter); + __m512 c2 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c2); + __m512 c1 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c1); + __m512 c0 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c0); + __m512 poly = _mm512_fmadd_ps(c2, xa, c1); + poly = _mm512_fmadd_ps(poly, xa, c0); + result = _mm512_fmadd_ps(poly, xr, half); + + return result; +} +#endif /*defined(LIBXSMM_INTRINSICS_AVX512_CORE)*/ + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_GELU_FWD(__m512 x) { + const __m512 c1 = _mm512_set1_ps( (float)0.79788); + const __m512 c2 = _mm512_set1_ps( (float)0.03568); + const __m512 c_half = _mm512_set1_ps( (float)0.5); + + __m512 x_half = _mm512_mul_ps( x, c_half ); + __m512 x_sq = _mm512_mul_ps( x, x ); + __m512 poly_x1 = _mm512_mul_ps(x, _mm512_fmadd_ps( x_sq, c2, c1)); + __m512 tanh_poly_x = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(poly_x1); + __m512 output = _mm512_fmadd_ps(tanh_poly_x, x_half, x_half); + + return output; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_GELU_BWD(__m512 x) { + const __m512 c1 = _mm512_set1_ps( (float)0.79788); + const __m512 c2 = _mm512_set1_ps( (float)0.03568); + const __m512 c3 = _mm512_set1_ps( (float)0.05352); + const __m512 c4 = _mm512_set1_ps( (float)0.39894); + const __m512 c_half = _mm512_set1_ps( (float)0.5); + const __m512 c_ones = _mm512_set1_ps( (float)1.0); + const __m512 c_minus_1 = _mm512_set1_ps( (float)-1.0); + + __m512 x_sq = _mm512_mul_ps( x, x ); + __m512 poly_x1 = _mm512_mul_ps(x, _mm512_fmadd_ps( x_sq, c2, c1)); + __m512 poly_x2 = _mm512_mul_ps(x, _mm512_fmadd_ps( x_sq, c3, c4)); + + __m512 tanh_poly_x = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(poly_x1); + __m512 out1 = _mm512_add_ps(c_ones, tanh_poly_x); + __m512 out2 = _mm512_add_ps(c_half, poly_x2); + __m512 out3 = _mm512_fmsub_ps(poly_x2, tanh_poly_x, out2); + __m512 out4 = _mm512_mul_ps(c_minus_1, out3); + __m512 output = _mm512_mul_ps(out1, out4); + + return output; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_EXP_PS_2DTS(__m512 in) { + const __m512 log2_e = _mm512_set1_ps(1.442695f); + const __m512 half = _mm512_set1_ps(0.5f); + const __m512 c2 = _mm512_set1_ps(0.240226507f); + const __m512 c1 = _mm512_set1_ps(0.452920674f); + const __m512 c0 = _mm512_set1_ps(0.713483036f); + + const __m512 x = _mm512_fmadd_ps(in, log2_e, half); +#if 1 + const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION)); +#else + const __m512 y = _mm512_reduce_ps(x, 1); +#endif + const __m512 t1 = _mm512_fmadd_ps( y, c2, c1); + const __m512 two_to_y = _mm512_fmadd_ps( y, t1, c0); + const __m512 exp = _mm512_scalef_ps( two_to_y, x ); + + return exp; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_EXP_PS_3DTS(__m512 in) { + const __m512 log2_e = _mm512_set1_ps(1.442695f); + const __m512 half = _mm512_set1_ps(0.5f); + const __m512 c3 = _mm512_set1_ps(0.05550410866f); + const __m512 c2 = _mm512_set1_ps(0.15697034396f); + const __m512 c1 = _mm512_set1_ps(0.49454875509f); + const __m512 c0 = _mm512_set1_ps(0.70654502287f); + + const __m512 x = _mm512_fmadd_ps(in, log2_e, half); +#if 1 + const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION)); +#else + const __m512 y = _mm512_reduce_ps(x, 1); +#endif + const __m512 t1 = _mm512_fmadd_ps( y, c3, c2); + const __m512 t2 = _mm512_fmadd_ps( y, t1, c1); + const __m512 two_to_y = _mm512_fmadd_ps( y, t2, c0); + const __m512 exp = _mm512_scalef_ps( two_to_y, x ); + + return exp; +} + +# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0 +LIBXSMM_PRAGMA_OPTIMIZE_OFF /* avoid ICE in case of symbols (-g) */ +# endif +/** Generate random number in the interval [0, 1); not thread-safe. + * this is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c */ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(void) { + const __m512i result = _mm512_add_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3)); + const __m512i s = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), 9); + __m512i t; + LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), LIBXSMM_INTRINSICS_MM512_RNG_STATE(0)); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), LIBXSMM_INTRINSICS_MM512_RNG_STATE(1)); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), LIBXSMM_INTRINSICS_MM512_RNG_STATE(2)); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3)); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), s); + t = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 11); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_or_epi32(t, _mm512_srli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 32 - 11)); + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_PS(void) { + const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(), 9 ); + const __m512 one = _mm512_set1_ps(1.0f); + return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one); +} + +/** Generate random number in the interval [0, 1); thread save, state needs to be managed by user. + * this is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c */ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32(unsigned int* stateptr) { + __m512i state_0 = _mm512_loadu_si512( stateptr ); + __m512i state_1 = _mm512_loadu_si512( stateptr+16 ); + __m512i state_2 = _mm512_loadu_si512( stateptr+32 ); + __m512i state_3 = _mm512_loadu_si512( stateptr+48 ); + const __m512i result = _mm512_add_epi32(state_0, state_3); + const __m512i s = _mm512_slli_epi32(state_1, 9); + __m512i t; + state_2 = _mm512_xor_epi32(state_2, state_0); + state_3 = _mm512_xor_epi32(state_3, state_1); + state_1 = _mm512_xor_epi32(state_1, state_2); + state_0 = _mm512_xor_epi32(state_0, state_3); + state_2 = _mm512_xor_epi32(state_2, s); + _mm512_storeu_si512( stateptr , state_0 ); + _mm512_storeu_si512( stateptr+16, state_1 ); + _mm512_storeu_si512( stateptr+32, state_2 ); + t = _mm512_slli_epi32(state_3, 11); + state_3 = _mm512_or_epi32(t, _mm512_srli_epi32(state_3, 32 - 11)); + _mm512_storeu_si512( stateptr+48, state_3 ); + return result; +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_EXTSTATE_PS(unsigned int* stateptr) { + const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32(stateptr), 9 ); + const __m512 one = _mm512_set1_ps(1.0f); + return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one); +} +# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0 +LIBXSMM_PRAGMA_OPTIMIZE_ON +# endif +#endif /*__AVX512F__*/ + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#endif /*LIBXSMM_INTRINSICS_X86_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_macros.h b/third_party/libxsmm/include/libxsmm_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..43f3f0d51b3ffced05f678596f39afdc7214af18 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_macros.h @@ -0,0 +1,983 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MACROS_H +#define LIBXSMM_MACROS_H + +#include "libxsmm_config.h" + +/** Parameters the library was built for. */ +#define LIBXSMM_CACHELINE LIBXSMM_CONFIG_CACHELINE +#define LIBXSMM_ALIGNMENT LIBXSMM_CONFIG_ALIGNMENT +#define LIBXSMM_MALLOC LIBXSMM_CONFIG_MALLOC +#define LIBXSMM_ILP64 LIBXSMM_CONFIG_ILP64 +#define LIBXSMM_SYNC LIBXSMM_CONFIG_SYNC +#define LIBXSMM_JIT LIBXSMM_CONFIG_JIT + +/** Parameters of GEMM domain (static kernels, etc). */ +#define LIBXSMM_PREFETCH LIBXSMM_CONFIG_PREFETCH +#define LIBXSMM_MAX_MNK LIBXSMM_CONFIG_MAX_MNK +#define LIBXSMM_MAX_DIM LIBXSMM_CONFIG_MAX_DIM +#define LIBXSMM_MAX_M LIBXSMM_CONFIG_MAX_M +#define LIBXSMM_MAX_N LIBXSMM_CONFIG_MAX_N +#define LIBXSMM_MAX_K LIBXSMM_CONFIG_MAX_K +#define LIBXSMM_FLAGS LIBXSMM_CONFIG_FLAGS +#define LIBXSMM_ALPHA LIBXSMM_CONFIG_ALPHA +#define LIBXSMM_BETA LIBXSMM_CONFIG_BETA + +/** + * Use "make PLATFORM=1" to disable platform checks. + * The platform check is to bail-out with an error + * message for an attempt to build an upstream package + * and subsequently to list LIBXSMM as "broken" on + * that platform. + * Note: successful compilation on an unsupported + * platform is desired, but only fallback code is + * present at best. + */ +#if !defined(LIBXSMM_PLATFORM_FORCE) && 0 +# define LIBXSMM_PLATFORM_FORCE +#endif + +#if !defined(LIBXSMM_PLATFORM_X86) && ( \ + (defined(__x86_64__) && 0 != (__x86_64__)) || \ + (defined(__amd64__) && 0 != (__amd64__)) || \ + (defined(_M_X64) || defined(_M_AMD64)) || \ + (defined(__i386__) && 0 != (__i386__)) || \ + (defined(_M_IX86))) +# define LIBXSMM_PLATFORM_X86 +#endif +#if !defined(LIBXSMM_PLATFORM_AARCH64) && \ + (defined(__aarch64__) || defined(__arm64__)) +# define LIBXSMM_PLATFORM_AARCH64 +#endif +#if !defined(LIBXSMM_PLATFORM_SUPPORTED) +# if defined(LIBXSMM_PLATFORM_X86) || defined(LIBXSMM_PLATFORM_AARCH64) +# define LIBXSMM_PLATFORM_SUPPORTED +# elif !defined(LIBXSMM_PLATFORM_FORCE) +# error LIBXSMM requires X86_64, AArch64, or compatible CPUs! +# endif +#endif +#if !defined(LIBXSMM_BITS) +# if (defined(__SIZEOF_PTRDIFF_T__) && 4 < (__SIZEOF_PTRDIFF_T__)) || \ + (defined(__SIZE_MAX__) && (4294967295U < (__SIZE_MAX__))) || \ + (defined(__x86_64__) && 0 != (__x86_64__)) || \ + (defined(__amd64__) && 0 != (__amd64__)) || \ + (defined(_M_X64) || defined(_M_AMD64)) || \ + (defined(_WIN64)) || \ + (defined(__powerpc64)) || \ + (defined(__aarch64__)) +# define LIBXSMM_UNLIMITED 0xFFFFFFFFFFFFFFFF +# define LIBXSMM_BITS 64 +# elif !defined(LIBXSMM_PLATFORM_FORCE) && defined(NDEBUG) +# error LIBXSMM is only supported on 64-bit platforms! +# else /* JIT-generated code (among other issues) is not supported! */ +# define LIBXSMM_UNLIMITED 0xFFFFFFFF +# define LIBXSMM_BITS 32 +# endif +#endif + +#define LIBXSMM_STRINGIFY2(SYMBOL) #SYMBOL +#define LIBXSMM_STRINGIFY(SYMBOL) LIBXSMM_STRINGIFY2(SYMBOL) +#define LIBXSMM_TOSTRING(SYMBOL) LIBXSMM_STRINGIFY(SYMBOL) +#define LIBXSMM_CONCATENATE2(A, B) A##B +#define LIBXSMM_CONCATENATE3(A, B, C) LIBXSMM_CONCATENATE(LIBXSMM_CONCATENATE(A, B), C) +#define LIBXSMM_CONCATENATE4(A, B, C, D) LIBXSMM_CONCATENATE(LIBXSMM_CONCATENATE3(A, B, C), D) +#define LIBXSMM_CONCATENATE(A, B) LIBXSMM_CONCATENATE2(A, B) +#define LIBXSMM_FSYMBOL(SYMBOL) LIBXSMM_CONCATENATE(SYMBOL, _) +#define LIBXSMM_UNIQUE(NAME) LIBXSMM_CONCATENATE(NAME, __LINE__) +#define LIBXSMM_EXPAND(...) __VA_ARGS__ +#define LIBXSMM_ELIDE(...) + +/** + * Check given value against type-range (assertion). + * Note: allows "-1" for unsigned types. + */ +#if !defined(NDEBUG) +# define LIBXSMM_CHECK_ULLONG(VALUE) assert(-1 <= (VALUE) && (VALUE) <= ULLONG_MAX) +# define LIBXSMM_CHECK_LLONG(VALUE) assert(ULLONG_MIN <= (VALUE) && (VALUE) <= LLONG_MAX) +# define LIBXSMM_CHECK_ULONG(VALUE) assert(-1 <= (VALUE) && (VALUE) <= ULONG_MAX) +# define LIBXSMM_CHECK_LONG(VALUE) assert(LONG_MIN <= (VALUE) && (VALUE) <= LONG_MAX) +# define LIBXSMM_CHECK_USHORT(VALUE) assert(-1 <= (VALUE) && (VALUE) <= USHRT_MAX) +# define LIBXSMM_CHECK_SHORT(VALUE) assert(SHRT_MIN <= (VALUE) && (VALUE) <= SHRT_MAX) +# define LIBXSMM_CHECK_UCHAR(VALUE) assert(-1 <= (VALUE) && (VALUE) <= UCHAR_MAX) +# define LIBXSMM_CHECK_ICHAR(VALUE) assert(SCHAR_MIN <= (VALUE) && (VALUE) <= SCHAR_MAX) +# define LIBXSMM_CHECK_UINT(VALUE) assert(-1 <= (VALUE) && (VALUE) <= UINT_MAX) +# define LIBXSMM_CHECK_INT(VALUE) assert(INT_MIN <= (VALUE) && (VALUE) <= INT_MAX) +#else +# define LIBXSMM_CHECK_ULLONG(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_LLONG(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_ULONG(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_LONG(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_USHORT(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_SHORT(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_UCHAR(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_ICHAR(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_UINT(VALUE) 0/*dummy*/ +# define LIBXSMM_CHECK_INT(VALUE) 0/*dummy*/ +#endif + +/** + * Perform verbose type-cast with following two advantages: + * (1) Make it easy to locate/find the type-cast. + * (2) Range-check to ensure fitting into type. + */ +#define LIBXSMM_CAST_ULLONG(VALUE) (LIBXSMM_CHECK_ULLONG(VALUE), (unsigned long long)(VALUE)) +#define LIBXSMM_CAST_LLONG(VALUE) (LIBXSMM_CHECK_LLONG(VALUE), (/*signed*/long long)(VALUE)) +#define LIBXSMM_CAST_ULONG(VALUE) (LIBXSMM_CHECK_ULONG(VALUE), (unsigned long)(VALUE)) +#define LIBXSMM_CAST_LONG(VALUE) (LIBXSMM_CHECK_LONG(VALUE), (/*signed*/long)(VALUE)) +#define LIBXSMM_CAST_USHORT(VALUE) (LIBXSMM_CHECK_USHORT(VALUE), (unsigned short)(VALUE)) +#define LIBXSMM_CAST_SHORT(VALUE) (LIBXSMM_CHECK_SHORT(VALUE), (/*signed*/short)(VALUE)) +#define LIBXSMM_CAST_UCHAR(VALUE) (LIBXSMM_CHECK_UCHAR(VALUE), (unsigned char)(VALUE)) +#define LIBXSMM_CAST_ICHAR(VALUE) (LIBXSMM_CHECK_ICHAR(VALUE), (signed char)(VALUE)) +#define LIBXSMM_CAST_UINT(VALUE) (LIBXSMM_CHECK_UINT(VALUE), (unsigned int)(VALUE)) +#define LIBXSMM_CAST_INT(VALUE) (LIBXSMM_CHECK_INT(VALUE), (/*signed*/int)(VALUE)) + +/** Use LIBXSMM_VERSION2 instead of LIBXSMM_VERSION3, e.g., if __GNUC_PATCHLEVEL__ or __clang_patchlevel__ is zero (0). */ +#define LIBXSMM_VERSION2(MAJOR, MINOR) ((MAJOR) * 10000 + (MINOR) * 100) +#define LIBXSMM_VERSION3(MAJOR, MINOR, UPDATE) (LIBXSMM_VERSION2(MAJOR, MINOR) + (UPDATE)) +#define LIBXSMM_VERSION4(MAJOR, MINOR, UPDATE, PATCH) \ + (((0x7F & (MAJOR)) << 24) | ((0x1F & (MINOR)) << 19) | ((0x1F & (UPDATE)) << 14) | (0x3FFF & (PATCH))) +#define LIBXSMM_VERSION41(VERSION) (((VERSION) >> 24)) +#define LIBXSMM_VERSION42(VERSION) (((VERSION) >> 19) & 0x1F) +#define LIBXSMM_VERSION43(VERSION) (((VERSION) >> 14) & 0x1F) +#define LIBXSMM_VERSION44(VERSION) (((VERSION)) & 0x3FFF) + +#if !defined(LIBXSMM_UNPACKED) && (defined(_CRAYC) || defined(LIBXSMM_OFFLOAD_BUILD) || \ + (0 == LIBXSMM_SYNC)/*Windows: missing pack(pop) error*/) +# define LIBXSMM_UNPACKED +#endif +#if defined(_WIN32) && !defined(__GNUC__) && !defined(__clang__) +# define LIBXSMM_ATTRIBUTE(A) __declspec(A) +# if defined(__cplusplus) +# define LIBXSMM_INLINE_ALWAYS __forceinline +# else +# define LIBXSMM_INLINE_ALWAYS static __forceinline +# endif +# define LIBXSMM_ALIGNED(DECL, N) LIBXSMM_ATTRIBUTE(align(N)) DECL +# if !defined(LIBXSMM_UNPACKED) +# define LIBXSMM_PACKED(TYPE) LIBXSMM_PRAGMA(pack(1)) TYPE +# endif +# define LIBXSMM_CDECL __cdecl +#elif (defined(__GNUC__) || defined(__clang__) || defined(__PGI)) +# define LIBXSMM_ATTRIBUTE(A) __attribute__((A)) +# define LIBXSMM_INLINE_ALWAYS LIBXSMM_ATTRIBUTE(always_inline) LIBXSMM_INLINE +# define LIBXSMM_ALIGNED(DECL, N) LIBXSMM_ATTRIBUTE(aligned(N)) DECL +# if !defined(LIBXSMM_UNPACKED) +# define LIBXSMM_PACKED(TYPE) TYPE LIBXSMM_ATTRIBUTE(__packed__) +# endif +# define LIBXSMM_CDECL LIBXSMM_ATTRIBUTE(cdecl) +#else +# define LIBXSMM_ATTRIBUTE(A) +# define LIBXSMM_INLINE_ALWAYS LIBXSMM_INLINE +# define LIBXSMM_ALIGNED(DECL, N) DECL +# define LIBXSMM_CDECL +#endif +#if !defined(LIBXSMM_PACKED) +# define LIBXSMM_PACKED(TYPE) TYPE +# if !defined(LIBXSMM_UNPACKED) +# define LIBXSMM_UNPACKED +# endif +#endif +#if !defined(LIBXSMM_UNPACKED) && 0 +/* no braces around EXPR */ +# define LIBXSMM_PAD(EXPR) EXPR; +#endif +#if !defined(LIBXSMM_PAD) +# define LIBXSMM_PAD(EXPR) +#endif + +#if defined(__INTEL_COMPILER) +# if !defined(__INTEL_COMPILER_UPDATE) +# define LIBXSMM_INTEL_COMPILER __INTEL_COMPILER +# else +# define LIBXSMM_INTEL_COMPILER (__INTEL_COMPILER + __INTEL_COMPILER_UPDATE) +# endif +#elif defined(__INTEL_COMPILER_BUILD_DATE) +# define LIBXSMM_INTEL_COMPILER ((__INTEL_COMPILER_BUILD_DATE / 10000 - 2000) * 100) +#endif + +/* LIBXSMM_ATTRIBUTE_USED: mark library functions as used to avoid warning */ +#if defined(__GNUC__) || defined(__clang__) || (defined(__INTEL_COMPILER) && !defined(_WIN32)) +# if !defined(__cplusplus) || !defined(__clang__) +# define LIBXSMM_ATTRIBUTE_COMMON LIBXSMM_ATTRIBUTE(common) +# else +# define LIBXSMM_ATTRIBUTE_COMMON +# endif +# define LIBXSMM_ATTRIBUTE_MALLOC LIBXSMM_ATTRIBUTE(malloc) +# define LIBXSMM_ATTRIBUTE_UNUSED LIBXSMM_ATTRIBUTE(unused) +# define LIBXSMM_ATTRIBUTE_USED LIBXSMM_ATTRIBUTE(used) +#else +# if defined(_WIN32) +# define LIBXSMM_ATTRIBUTE_COMMON LIBXSMM_ATTRIBUTE(selectany) +# else +# define LIBXSMM_ATTRIBUTE_COMMON +# endif +# define LIBXSMM_ATTRIBUTE_MALLOC +# define LIBXSMM_ATTRIBUTE_UNUSED +# define LIBXSMM_ATTRIBUTE_USED +#endif +#if defined(__clang__) && !defined(__INTEL_COMPILER) +# define LIBXSMM_ATTRIBUTE_NO_SANITIZE(KIND) LIBXSMM_ATTRIBUTE(no_sanitize(LIBXSMM_STRINGIFY(KIND))) +#elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 8) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) \ + && !defined(__INTEL_COMPILER) +# define LIBXSMM_ATTRIBUTE_NO_SANITIZE(KIND) LIBXSMM_ATTRIBUTE(LIBXSMM_CONCATENATE(no_sanitize_, KIND)) +#else +# define LIBXSMM_ATTRIBUTE_NO_SANITIZE(KIND) +#endif + +#if defined(__cplusplus) +# define LIBXSMM_VARIADIC ... +# define LIBXSMM_EXTERN extern "C" +# define LIBXSMM_EXTERN_C extern "C" +# define LIBXSMM_INLINE_KEYWORD inline +# define LIBXSMM_INLINE LIBXSMM_INLINE_KEYWORD +# if defined(__GNUC__) || defined(_CRAYC) +# define LIBXSMM_CALLER __PRETTY_FUNCTION__ +# elif defined(_MSC_VER) +# define LIBXSMM_CALLER __FUNCDNAME__ +# define LIBXSMM_FUNCNAME __FUNCTION__ +# else +# define LIBXSMM_CALLER __FUNCNAME__ +# endif +#else /* C */ +# define LIBXSMM_VARIADIC +# define LIBXSMM_EXTERN extern +# define LIBXSMM_EXTERN_C +# if defined(__STDC_VERSION__) && (199901L <= __STDC_VERSION__) /*C99*/ +# define LIBXSMM_PRAGMA(DIRECTIVE) _Pragma(LIBXSMM_STRINGIFY(DIRECTIVE)) +# define LIBXSMM_CALLER __func__ +# define LIBXSMM_RESTRICT restrict +# define LIBXSMM_INLINE_KEYWORD inline +# elif defined(_MSC_VER) +# define LIBXSMM_CALLER __FUNCDNAME__ +# define LIBXSMM_FUNCNAME __FUNCTION__ +# define LIBXSMM_INLINE_KEYWORD __inline +# define LIBXSMM_INLINE_FIXUP +# elif defined(__GNUC__) && !defined(__STRICT_ANSI__) +# define LIBXSMM_CALLER __PRETTY_FUNCTION__ +# endif +# if !defined(LIBXSMM_INLINE_KEYWORD) +# define LIBXSMM_INLINE_KEYWORD +# define LIBXSMM_INLINE_FIXUP +# endif +/* LIBXSMM_ATTRIBUTE_USED: increases compile-time of header-only by a large factor */ +# define LIBXSMM_INLINE static LIBXSMM_INLINE_KEYWORD LIBXSMM_ATTRIBUTE_UNUSED +#endif /*__cplusplus*/ +#if !defined(LIBXSMM_CALLER) +# define LIBXSMM_CALLER NULL +#endif +#if !defined(LIBXSMM_FUNCNAME) +# define LIBXSMM_FUNCNAME LIBXSMM_CALLER +#endif +#if !defined(LIBXSMM_CALLER_ID) +# if defined(__GNUC__) || 1 +# define LIBXSMM_CALLER_ID ((const void*)((uintptr_t)libxsmm_hash_string(LIBXSMM_CALLER))) +# else /* assume no string-pooling (perhaps unsafe) */ +# define LIBXSMM_CALLER_ID LIBXSMM_CALLER +# endif +#endif + +#if defined(LIBXSMM_OFFLOAD_BUILD) && \ + defined(__INTEL_OFFLOAD) && (!defined(_WIN32) || (1400 <= LIBXSMM_INTEL_COMPILER)) +# define LIBXSMM_OFFLOAD(A) LIBXSMM_ATTRIBUTE(target(A)) +# define LIBXSMM_NO_OFFLOAD(RTYPE, FN, ...) ((RTYPE (*)(LIBXSMM_VARIADIC))(FN))(__VA_ARGS__) +# if !defined(LIBXSMM_OFFLOAD_TARGET) +# define LIBXSMM_OFFLOAD_TARGET mic +# endif +#else +# define LIBXSMM_OFFLOAD(A) +# define LIBXSMM_NO_OFFLOAD(RTYPE, FN, ...) (FN)(__VA_ARGS__) +#endif +#define LIBXSMM_RETARGETABLE LIBXSMM_OFFLOAD(LIBXSMM_OFFLOAD_TARGET) + +#if !defined(__STATIC) && !defined(_WINDLL) && (defined(_WIN32) || defined(__CYGWIN__) || defined(__MINGW32__)) +# define __STATIC +#endif + +/* may include Clang and other compatible compilers */ +#if defined(__GNUC__) && !defined(_WIN32) && !defined(__CYGWIN__) && !defined(__MINGW32__) +# define LIBXSMM_VISIBILITY_INTERNAL LIBXSMM_ATTRIBUTE(visibility("internal")) +# define LIBXSMM_VISIBILITY_HIDDEN LIBXSMM_ATTRIBUTE(visibility("hidden")) +# define LIBXSMM_VISIBILITY_PUBLIC LIBXSMM_ATTRIBUTE(visibility("default")) +#endif +#if !defined(LIBXSMM_VISIBILITY_INTERNAL) +# define LIBXSMM_VISIBILITY_INTERNAL +#endif +#if !defined(LIBXSMM_VISIBILITY_HIDDEN) +# define LIBXSMM_VISIBILITY_HIDDEN +#endif +#if !defined(LIBXSMM_VISIBILITY_PUBLIC) +# define LIBXSMM_VISIBILITY_PUBLIC +#endif +#if !defined(LIBXSMM_VISIBILITY_PRIVATE) +# define LIBXSMM_VISIBILITY_PRIVATE LIBXSMM_VISIBILITY_HIDDEN +#endif + +/* Windows Dynamic Link Library (DLL) */ +#if !defined(__STATIC) && (defined(_WIN32) || defined(__CYGWIN__) || defined(__MINGW32__)) +# define LIBXSMM_VISIBILITY_EXPORT LIBXSMM_ATTRIBUTE(dllexport) +# define LIBXSMM_VISIBILITY_IMPORT LIBXSMM_ATTRIBUTE(dllimport) +#endif +#if !defined(LIBXSMM_VISIBILITY_EXPORT) +# define LIBXSMM_VISIBILITY_EXPORT LIBXSMM_VISIBILITY_PUBLIC +#endif +#if !defined(LIBXSMM_VISIBILITY_IMPORT) +# define LIBXSMM_VISIBILITY_IMPORT LIBXSMM_VISIBILITY_PUBLIC +#endif + +#if defined(LIBXSMM_SOURCE_H) /* header-only mode */ +# define LIBXSMM_API_VISIBILITY_EXPORT +# define LIBXSMM_API_VISIBILITY_IMPORT +# define LIBXSMM_API_VISIBILITY_INTERN +# define LIBXSMM_API_COMMON LIBXSMM_RETARGETABLE LIBXSMM_ATTRIBUTE_COMMON +# define LIBXSMM_API_TARGET LIBXSMM_API_INLINE +# define LIBXSMM_API_EXTERN LIBXSMM_EXTERN_C +#else /* classic ABI */ +# if defined(LIBXSMM_BUILD_EXT) +# define LIBXSMM_API_VISIBILITY_EXPORT LIBXSMM_VISIBILITY_IMPORT +# define LIBXSMM_API_VISIBILITY_IMPORT LIBXSMM_VISIBILITY_EXPORT +# define LIBXSMM_API_VISIBILITY_INTERN LIBXSMM_VISIBILITY_PRIVATE +# elif defined(LIBXSMM_BUILD) +# define LIBXSMM_API_VISIBILITY_EXPORT LIBXSMM_VISIBILITY_EXPORT +# define LIBXSMM_API_VISIBILITY_IMPORT LIBXSMM_VISIBILITY_IMPORT +# define LIBXSMM_API_VISIBILITY_INTERN LIBXSMM_VISIBILITY_PRIVATE +# else /* import */ +# define LIBXSMM_API_VISIBILITY_EXPORT LIBXSMM_VISIBILITY_IMPORT +# define LIBXSMM_API_VISIBILITY_IMPORT LIBXSMM_VISIBILITY_IMPORT +# define LIBXSMM_API_VISIBILITY_INTERN +# endif +# define LIBXSMM_API_COMMON LIBXSMM_RETARGETABLE +# define LIBXSMM_API_TARGET LIBXSMM_RETARGETABLE +# define LIBXSMM_API_EXTERN LIBXSMM_EXTERN +#endif + +#define LIBXSMM_API_VISIBILITY(VISIBILITY) LIBXSMM_CONCATENATE(LIBXSMM_API_VISIBILITY_, VISIBILITY) +#define LIBXSMM_APIVAR(DECL, VISIBILITY, EXTERN) EXTERN LIBXSMM_API_COMMON LIBXSMM_API_VISIBILITY(VISIBILITY) DECL +#define LIBXSMM_API_INLINE LIBXSMM_INLINE LIBXSMM_RETARGETABLE +#define LIBXSMM_API_DEF + +#if (!defined(__INTEL_COMPILER) || !defined(_WIN32)) +#define LIBXSMM_APIVAR_ALIGNED(DECL, VISIBILITY) LIBXSMM_ALIGNED(LIBXSMM_APIVAR(DECL, VISIBILITY, LIBXSMM_API_DEF), LIBXSMM_CONFIG_CACHELINE) +#else +#define LIBXSMM_APIVAR_ALIGNED(DECL, VISIBILITY) LIBXSMM_APIVAR(DECL, VISIBILITY, LIBXSMM_API_DEF) +#endif + +/** Public variable declaration (without definition) located in header file. */ +#define LIBXSMM_APIVAR_PUBLIC(DECL) LIBXSMM_APIVAR(DECL, EXPORT, LIBXSMM_API_EXTERN) +/** Public variable definition (complements declaration) located in source file. */ +#define LIBXSMM_APIVAR_PUBLIC_DEF(DECL) LIBXSMM_APIVAR_ALIGNED(DECL, EXPORT) +/** Private variable declaration (without definition) located in header file. */ +#define LIBXSMM_APIVAR_PRIVATE(DECL) LIBXSMM_APIVAR(DECL, INTERN, LIBXSMM_API_EXTERN) +/** Private variable definition (complements declaration) located in source file. */ +#define LIBXSMM_APIVAR_PRIVATE_DEF(DECL) LIBXSMM_APIVAR_ALIGNED(DECL, INTERN) +/** Private variable (declaration and definition) located in source file. */ +#define LIBXSMM_APIVAR_DEFINE(DECL) LIBXSMM_APIVAR_PRIVATE(DECL); LIBXSMM_APIVAR_PRIVATE_DEF(DECL) +/** Function decoration used for private functions. */ +#define LIBXSMM_API_INTERN LIBXSMM_API_EXTERN LIBXSMM_API_TARGET LIBXSMM_API_VISIBILITY(INTERN) +/** Function decoration used for public functions of LIBXSMMext library. */ +#define LIBXSMM_APIEXT LIBXSMM_API_EXTERN LIBXSMM_API_TARGET LIBXSMM_API_VISIBILITY(IMPORT) +/** Function decoration used for public functions of LIBXSMM library. */ +#define LIBXSMM_API LIBXSMM_API_EXTERN LIBXSMM_API_TARGET LIBXSMM_API_VISIBILITY(EXPORT) + +#if !defined(LIBXSMM_RESTRICT) +# if ((defined(__GNUC__) && !defined(__CYGWIN32__)) || defined(LIBXSMM_INTEL_COMPILER)) && !defined(_WIN32) +# define LIBXSMM_RESTRICT __restrict__ +# elif defined(_MSC_VER) || defined(LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_RESTRICT __restrict +# else +# define LIBXSMM_RESTRICT +# endif +#endif /*LIBXSMM_RESTRICT*/ + +#if !defined(LIBXSMM_PRAGMA) +# if defined(LIBXSMM_INTEL_COMPILER) || defined(_MSC_VER) +# define LIBXSMM_PRAGMA(DIRECTIVE) __pragma(LIBXSMM_EXPAND(DIRECTIVE)) +# else +# define LIBXSMM_PRAGMA(DIRECTIVE) +# endif +#endif /*LIBXSMM_PRAGMA*/ + +#if !defined(LIBXSMM_OPENMP_SIMD) && (defined(_OPENMP) && (201307 <= _OPENMP/*v4.0*/)) +# if defined(LIBXSMM_INTEL_COMPILER) +# if (1500 <= LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_OPENMP_SIMD +# endif +# elif defined(__GNUC__) +# if LIBXSMM_VERSION2(4, 9) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) +# define LIBXSMM_OPENMP_SIMD +# endif +# else +# define LIBXSMM_OPENMP_SIMD +# endif +#endif + +#if !defined(LIBXSMM_INTEL_COMPILER) || (LIBXSMM_INTEL_COMPILER < 9900) +# if defined(LIBXSMM_OPENMP_SIMD) +# define LIBXSMM_PRAGMA_SIMD_REDUCTION(EXPRESSION) LIBXSMM_PRAGMA(omp simd reduction(EXPRESSION)) +# define LIBXSMM_PRAGMA_SIMD_COLLAPSE(N) LIBXSMM_PRAGMA(omp simd collapse(N)) +# define LIBXSMM_PRAGMA_SIMD_PRIVATE(...) LIBXSMM_PRAGMA(omp simd private(__VA_ARGS__)) +# define LIBXSMM_PRAGMA_SIMD LIBXSMM_PRAGMA(omp simd) +# elif defined(__INTEL_COMPILER) +# define LIBXSMM_PRAGMA_SIMD_REDUCTION(EXPRESSION) LIBXSMM_PRAGMA(simd reduction(EXPRESSION)) +# define LIBXSMM_PRAGMA_SIMD_COLLAPSE(N) LIBXSMM_PRAGMA(simd collapse(N)) +# define LIBXSMM_PRAGMA_SIMD_PRIVATE(...) LIBXSMM_PRAGMA(simd private(__VA_ARGS__)) +# define LIBXSMM_PRAGMA_SIMD LIBXSMM_PRAGMA(simd) +# endif +#endif +#if !defined(LIBXSMM_PRAGMA_SIMD) +# define LIBXSMM_PRAGMA_SIMD_REDUCTION(EXPRESSION) +# define LIBXSMM_PRAGMA_SIMD_COLLAPSE(N) +# define LIBXSMM_PRAGMA_SIMD_PRIVATE(...) +# define LIBXSMM_PRAGMA_SIMD +#endif + +#if defined(__INTEL_COMPILER) +# define LIBXSMM_PRAGMA_NONTEMPORAL(...) LIBXSMM_PRAGMA(vector nontemporal(__VA_ARGS__)) +# define LIBXSMM_PRAGMA_VALIGNED LIBXSMM_PRAGMA(vector aligned) +# define LIBXSMM_PRAGMA_NOVECTOR LIBXSMM_PRAGMA(novector) +# define LIBXSMM_PRAGMA_FORCEINLINE LIBXSMM_PRAGMA(forceinline) +# define LIBXSMM_PRAGMA_LOOP_COUNT(MIN, MAX, AVG) LIBXSMM_PRAGMA(loop_count min=MIN max=MAX avg=AVG) +# define LIBXSMM_PRAGMA_UNROLL_AND_JAM(N) LIBXSMM_PRAGMA(unroll_and_jam(N)) +# define LIBXSMM_PRAGMA_UNROLL_N(N) LIBXSMM_PRAGMA(unroll(N)) +# define LIBXSMM_PRAGMA_UNROLL LIBXSMM_PRAGMA(unroll) +# define LIBXSMM_PRAGMA_VALIGNED_VAR(A) LIBXSMM_ASSUME_ALIGNED(A, LIBXSMM_ALIGNMENT); +/*# define LIBXSMM_UNUSED(VARIABLE) LIBXSMM_PRAGMA(unused(VARIABLE))*/ +#else +# if defined(LIBXSMM_OPENMP_SIMD) && (201811 <= _OPENMP/*v5.0*/) +# define LIBXSMM_PRAGMA_NONTEMPORAL(...) LIBXSMM_PRAGMA(omp simd nontemporal(__VA_ARGS__)) +# else +# define LIBXSMM_PRAGMA_NONTEMPORAL(...) +# endif +# if defined(__clang__) +# define LIBXSMM_PRAGMA_VALIGNED_VAR(A) +# define LIBXSMM_PRAGMA_VALIGNED +# define LIBXSMM_PRAGMA_NOVECTOR LIBXSMM_PRAGMA(clang loop vectorize(disable)) +# define LIBXSMM_PRAGMA_FORCEINLINE +# define LIBXSMM_PRAGMA_LOOP_COUNT(MIN, MAX, AVG) LIBXSMM_PRAGMA(unroll(AVG)) +# define LIBXSMM_PRAGMA_UNROLL_AND_JAM(N) LIBXSMM_PRAGMA(unroll(N)) +# define LIBXSMM_PRAGMA_UNROLL_N(N) LIBXSMM_PRAGMA(unroll(N)) +# define LIBXSMM_PRAGMA_UNROLL LIBXSMM_PRAGMA_UNROLL_N(4) +# else +# define LIBXSMM_PRAGMA_VALIGNED_VAR(A) +# define LIBXSMM_PRAGMA_VALIGNED +# define LIBXSMM_PRAGMA_NOVECTOR +# define LIBXSMM_PRAGMA_FORCEINLINE +# define LIBXSMM_PRAGMA_LOOP_COUNT(MIN, MAX, AVG) +# define LIBXSMM_PRAGMA_UNROLL_AND_JAM(N) +# define LIBXSMM_PRAGMA_UNROLL +# endif +#endif +#if !defined(LIBXSMM_PRAGMA_UNROLL_N) +# if defined(__GNUC__) && (LIBXSMM_VERSION2(8, 3) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) +# define LIBXSMM_PRAGMA_UNROLL_N(N) LIBXSMM_PRAGMA(GCC unroll N) +# else +# define LIBXSMM_PRAGMA_UNROLL_N(N) +# endif +#endif + +#if defined(LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_PRAGMA_OPTIMIZE_OFF LIBXSMM_PRAGMA(optimize("", off)) +# define LIBXSMM_PRAGMA_OPTIMIZE_ON LIBXSMM_PRAGMA(optimize("", on)) +#elif defined(__clang__) +# define LIBXSMM_PRAGMA_OPTIMIZE_OFF LIBXSMM_PRAGMA(clang optimize off) +# define LIBXSMM_PRAGMA_OPTIMIZE_ON LIBXSMM_PRAGMA(clang optimize on) +#elif defined(__GNUC__) +# define LIBXSMM_PRAGMA_OPTIMIZE_OFF LIBXSMM_PRAGMA(GCC push_options) LIBXSMM_PRAGMA(GCC optimize("O0")) +# define LIBXSMM_PRAGMA_OPTIMIZE_ON LIBXSMM_PRAGMA(GCC pop_options) +#else +# define LIBXSMM_PRAGMA_OPTIMIZE_OFF +# define LIBXSMM_PRAGMA_OPTIMIZE_ON +#endif + +#if defined(_OPENMP) && (200805 <= _OPENMP/*v3.0*/) \ + && defined(NDEBUG) /* CCE complains for debug builds */ +# define LIBXSMM_OPENMP_COLLAPSE(N) collapse(N) +#else +# define LIBXSMM_OPENMP_COLLAPSE(N) +#endif + +/** LIBXSMM_UP2POT rounds up to the next power of two (POT). */ +#define LIBXSMM_UP2POT_01(N) ((N) | ((N) >> 1)) +#define LIBXSMM_UP2POT_02(N) (LIBXSMM_UP2POT_01(N) | (LIBXSMM_UP2POT_01(N) >> 2)) +#define LIBXSMM_UP2POT_04(N) (LIBXSMM_UP2POT_02(N) | (LIBXSMM_UP2POT_02(N) >> 4)) +#define LIBXSMM_UP2POT_08(N) (LIBXSMM_UP2POT_04(N) | (LIBXSMM_UP2POT_04(N) >> 8)) +#define LIBXSMM_UP2POT_16(N) (LIBXSMM_UP2POT_08(N) | (LIBXSMM_UP2POT_08(N) >> 16)) +#define LIBXSMM_UP2POT_32(N) (LIBXSMM_UP2POT_16(N) | (LIBXSMM_UP2POT_16(N) >> 32)) +#define LIBXSMM_UP2POT(N) (LIBXSMM_UP2POT_32((unsigned long long)(N) - LIBXSMM_MIN(1, N)) + LIBXSMM_MIN(1, N)) +#define LIBXSMM_LO2POT(N) (LIBXSMM_UP2POT_32((unsigned long long)(N) >> 1) + LIBXSMM_MIN(1, N)) + +#define LIBXSMM_UPDIV(N, MULT) (((N) + ((MULT) - 1)) / (MULT)) +#define LIBXSMM_UP(N, MULT) (LIBXSMM_UPDIV(N, MULT) * (MULT)) +#define LIBXSMM_UP2(N, NPOT) (((N) + ((NPOT) - 1)) & ~((NPOT) - 1)) +#define LIBXSMM_ABS(A) (0 <= (A) ? (A) : -(A)) +#define LIBXSMM_MIN(A, B) ((A) < (B) ? (A) : (B)) +#define LIBXSMM_MAX(A, B) ((A) < (B) ? (B) : (A)) +#define LIBXSMM_MOD(A, N) ((A) % (N)) +#define LIBXSMM_MOD2(A, NPOT) ((A) & ((NPOT) - 1)) +#define LIBXSMM_DELTA(T0, T1) ((T0) < (T1) ? ((T1) - (T0)) : ((T0) - (T1))) +#define LIBXSMM_CLMP(VALUE, LO, HI) ((LO) < (VALUE) ? ((VALUE) <= (HI) ? (VALUE) : LIBXSMM_MIN(VALUE, HI)) : LIBXSMM_MAX(LO, VALUE)) +#define LIBXSMM_SIZEOF(START, LAST) (((const char*)(LAST)) - ((const char*)(START)) + sizeof(*LAST)) +#define LIBXSMM_FEQ(A, B) ((A) == (B)) +#define LIBXSMM_NEQ(A, B) ((A) != (B)) +#define LIBXSMM_ISPOT(A) (0 != (A) && !((A) & ((A) - 1))) +#define LIBXSMM_ISWAP(A, B) (((A) ^= (B)), ((B) ^= (A)), ((A) ^= (B))) +#define LIBXSMM_ISNAN(A) LIBXSMM_NEQ(A, A) +#define LIBXSMM_NOTNAN(A) LIBXSMM_FEQ(A, A) +#define LIBXSMM_ROUNDX(TYPE, A) ((TYPE)((long long)(0 <= (A) ? ((double)(A) + 0.5) : ((double)(A) - 0.5)))) +#define LIBXSMM_CONST_VOID_PTR(A) *((const void**)&(A)) + +/** Makes some functions available independent of C99 support. */ +#if defined(__STDC_VERSION__) && (199901L/*C99*/ <= __STDC_VERSION__) +# if defined(__PGI) +# define LIBXSMM_POWF(A, B) ((float)pow((float)(A), (float)(B))) +# else +# define LIBXSMM_POWF(A, B) powf(A, B) +# endif +# define LIBXSMM_FREXPF(A, B) frexpf(A, B) +# define LIBXSMM_ROUNDF(A) roundf(A) +# define LIBXSMM_ROUND(A) round(A) +# define LIBXSMM_TANHF(A) tanhf(A) +# define LIBXSMM_SQRTF(A) sqrtf(A) +# define LIBXSMM_EXP2F(A) exp2f(A) +# define LIBXSMM_LOG2F(A) log2f(A) +# define LIBXSMM_ERFF(A) erff(A) +# define LIBXSMM_EXP2(A) exp2(A) +# define LIBXSMM_LOG2(A) log2(A) +# define LIBXSMM_EXPF(A) expf(A) +# define LIBXSMM_LOGF(A) logf(A) +#else +# define LIBXSMM_POWF(A, B) ((float)pow((float)(A), (float)(B))) +# define LIBXSMM_FREXPF(A, B) ((float)frexp((float)(A), B)) +# define LIBXSMM_ROUNDF(A) LIBXSMM_ROUNDX(float, A) +# define LIBXSMM_ROUND(A) LIBXSMM_ROUNDX(double, A) +# define LIBXSMM_TANHF(A) ((float)tanh((float)(A))) +# define LIBXSMM_SQRTF(A) ((float)sqrt((float)(A))) +# define LIBXSMM_EXP2F(A) LIBXSMM_POWF(2, A) +# define LIBXSMM_LOG2F(A) ((float)LIBXSMM_LOG2((float)(A))) +# define LIBXSMM_ERFF(A) ((float)erf((float)(A))) +# define LIBXSMM_EXP2(A) pow(2.0, A) +# define LIBXSMM_LOG2(A) (log(A) * (1.0 / (M_LN2))) +# define LIBXSMM_EXPF(A) ((float)exp((float)(A))) +# define LIBXSMM_LOGF(A) ((float)log((float)(A))) +#endif + +#if defined(LIBXSMM_INTEL_COMPILER) +# if (1700 <= LIBXSMM_INTEL_COMPILER) +# define LIBXSMM_ASSUME(EXPRESSION) __assume(EXPRESSION) +# else +# define LIBXSMM_ASSUME(EXPRESSION) assert(EXPRESSION) +# endif +#elif defined(_MSC_VER) +# define LIBXSMM_ASSUME(EXPRESSION) __assume(EXPRESSION) +#elif defined(__GNUC__) && !defined(_CRAYC) && (LIBXSMM_VERSION2(4, 5) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) +# define LIBXSMM_ASSUME(EXPRESSION) do { if (!(EXPRESSION)) __builtin_unreachable(); } while(0) +#else +# define LIBXSMM_ASSUME(EXPRESSION) assert(EXPRESSION) +#endif + +#if defined(__INTEL_COMPILER) +# define LIBXSMM_ASSUME_ALIGNED(A, N) __assume_aligned(A, N) +#else +# define LIBXSMM_ASSUME_ALIGNED(A, N) assert(0 == ((uintptr_t)(A)) % (N)) +#endif +#define LIBXSMM_ALIGN(POINTER, ALIGNMENT/*POT*/) ((POINTER) + (LIBXSMM_UP2((uintptr_t)(POINTER), ALIGNMENT) - ((uintptr_t)(POINTER))) / sizeof(*(POINTER))) +#define LIBXSMM_FOLD2(POINTER, ALIGNMENT, NPOT) LIBXSMM_MOD2(((uintptr_t)(POINTER) / (ALIGNMENT)), NPOT) + +#if defined(_MSC_VER) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) /* account for incorrect handling of __VA_ARGS__ */ +# define LIBXSMM_SELECT_ELEMENT(INDEX1/*one-based*/, .../*elements*/) LIBXSMM_CONCATENATE(LIBXSMM_SELECT_ELEMENT_, INDEX1)LIBXSMM_EXPAND((__VA_ARGS__)) +#else +# define LIBXSMM_SELECT_ELEMENT(INDEX1/*one-based*/, .../*elements*/) LIBXSMM_CONCATENATE(LIBXSMM_SELECT_ELEMENT_, INDEX1)(__VA_ARGS__) +#endif +#define LIBXSMM_SELECT_ELEMENT_1(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E0 +#define LIBXSMM_SELECT_ELEMENT_2(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E1 +#define LIBXSMM_SELECT_ELEMENT_3(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E2 +#define LIBXSMM_SELECT_ELEMENT_4(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E3 +#define LIBXSMM_SELECT_ELEMENT_5(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E4 +#define LIBXSMM_SELECT_ELEMENT_6(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E5 +#define LIBXSMM_SELECT_ELEMENT_7(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E6 +#define LIBXSMM_SELECT_ELEMENT_8(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E7 +#define LIBXSMM_SELECT_ELEMENT_9(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E8 +#define LIBXSMM_SELECT_ELEMENT_10(E0, E1, E2, E3, E4, E5, E6, E7, E8, E9) E9 +#define LIBXSMM_SELECT_HEAD_AUX(A, ...) (A) +#define LIBXSMM_SELECT_HEAD(...) LIBXSMM_EXPAND(LIBXSMM_SELECT_HEAD_AUX(__VA_ARGS__, 0/*dummy*/)) +#define LIBXSMM_SELECT_TAIL(A, ...) __VA_ARGS__ + +/** + * For VLAs, check EXACTLY for C99 since a C11-conforming compiler may not provide VLAs. + * However, some compilers (Intel) may signal support for VLA even with strict ANSI (C89). + * To ultimately disable VLA-support, define LIBXSMM_NO_VLA (make VLA=0). + * VLA-support is signaled by LIBXSMM_VLA. + */ +#if !defined(LIBXSMM_VLA) && !defined(LIBXSMM_NO_VLA) && !defined(__PGI) && ( \ + (defined(__STDC_VERSION__) && (199901L/*C99*/ == __STDC_VERSION__ || (!defined(__STDC_NO_VLA__) && 199901L/*C99*/ < __STDC_VERSION__))) || \ + (defined(__GNUC__) && LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) && !defined(__STRICT_ANSI__) && !defined(__cplusplus)) || \ + (defined(LIBXSMM_INTEL_COMPILER) && !defined(_WIN32) && !defined(__cplusplus)) || \ + (defined(__INTEL_COMPILER) && !defined(_WIN32))) +# define LIBXSMM_VLA +#endif + +/** + * LIBXSMM_INDEX1 calculates the linear address for a given set of (multiple) indexes/bounds. + * Syntax: LIBXSMM_INDEX1(, , ..., , , ..., ). + * Please note that the leading dimension (s0) is omitted in the above syntax! + * TODO: support leading dimension (pitch/stride). + */ +#if defined(_MSC_VER) && !defined(__clang__) /* account for incorrect handling of __VA_ARGS__ */ +# define LIBXSMM_INDEX1(NDIMS, ...) LIBXSMM_CONCATENATE(LIBXSMM_INDEX1_, NDIMS)LIBXSMM_EXPAND((__VA_ARGS__)) +#else +# define LIBXSMM_INDEX1(NDIMS, ...) LIBXSMM_CONCATENATE(LIBXSMM_INDEX1_, NDIMS)(__VA_ARGS__) +#endif +#define LIBXSMM_INDEX1_1(...) ((size_t)LIBXSMM_SELECT_HEAD(__VA_ARGS__)) +#define LIBXSMM_INDEX1_2(I0, I1, S1) (LIBXSMM_INDEX1_1(I0) * ((size_t)S1) + (size_t)I1) +#define LIBXSMM_INDEX1_3(I0, I1, I2, S1, S2) (LIBXSMM_INDEX1_2(I0, I1, S1) * ((size_t)S2) + (size_t)I2) +#define LIBXSMM_INDEX1_4(I0, I1, I2, I3, S1, S2, S3) (LIBXSMM_INDEX1_3(I0, I1, I2, S1, S2) * ((size_t)S3) + (size_t)I3) +#define LIBXSMM_INDEX1_5(I0, I1, I2, I3, I4, S1, S2, S3, S4) (LIBXSMM_INDEX1_4(I0, I1, I2, I3, S1, S2, S3) * ((size_t)S4) + (size_t)I4) +#define LIBXSMM_INDEX1_6(I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, S5) (LIBXSMM_INDEX1_5(I0, I1, I2, I3, I4, S1, S2, S3, S4) * ((size_t)S5) + (size_t)I5) +#define LIBXSMM_INDEX1_7(I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, S6) (LIBXSMM_INDEX1_6(I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, S5) * ((size_t)S6) + (size_t)I6) +#define LIBXSMM_INDEX1_8(I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, S7) (LIBXSMM_INDEX1_7(I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, S6) * ((size_t)S7) + (size_t)I7) +#define LIBXSMM_INDEX1_9(I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, S8) (LIBXSMM_INDEX1_8(I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, S7) * ((size_t)S8) + (size_t)I8) +#define LIBXSMM_INDEX1_10(I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, S1, S2, S3, S4, S5, S6, S7, S8, S9) (LIBXSMM_INDEX1_9(I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, S8) * ((size_t)S9) + (size_t)I9) + +/** + * LIBXSMM_VLA_DECL declares an array according to the given set of (multiple) bounds. + * Syntax: LIBXSMM_VLA_DECL(, , , , , ..., ). + * The element type can be "const" or otherwise qualified; initial value must be (const)element-type*. + * Please note that the syntax is similar to LIBXSMM_INDEX1, and the leading dimension (s0) is omitted! + * + * LIBXSMM_VLA_ACCESS gives the array element according to the given set of (multiple) indexes/bounds. + * Syntax: LIBXSMM_VLA_ACCESS(, , , ..., , , ..., ). + * Please note that the syntax is similar to LIBXSMM_INDEX1, and the leading dimension (s0) is omitted! + */ +#if !defined(LIBXSMM_VLA_POSTFIX) +# define LIBXSMM_VLA_POSTFIX _ +#endif +#if defined(LIBXSMM_VLA) +LIBXSMM_API_INLINE int libxsmm_nonconst_int(int i) { return i; } +# define LIBXSMM_VLA_ACCESS(NDIMS, ARRAY, ...) LIBXSMM_VLA_ACCESS_ND(NDIMS, LIBXSMM_CONCATENATE(ARRAY, LIBXSMM_VLA_POSTFIX), LIBXSMM_VLA_ACCESS_SINK, __VA_ARGS__) +# define LIBXSMM_VLA_ACCESS_SINK(S) + 0 * (S) +# define LIBXSMM_VLA_ACCESS_NONCONST(I) libxsmm_nonconst_int(I) +# define LIBXSMM_VLA_ACCESS_ND(NDIMS, ARRAY, XY, ...) LIBXSMM_CONCATENATE3(LIBXSMM_VLA_ACCESS_, NDIMS, D)(ARRAY, XY, __VA_ARGS__) +# define LIBXSMM_VLA_ACCESS_0D(ARRAY, XY, ...) (ARRAY)/*scalar*/ +# define LIBXSMM_VLA_ACCESS_1D(ARRAY, XY, ...) ((ARRAY)[LIBXSMM_VLA_ACCESS_NONCONST(LIBXSMM_SELECT_HEAD(__VA_ARGS__))]) +# define LIBXSMM_VLA_ACCESS_2D(ARRAY, XY, I0, I1, ...) (((ARRAY) XY(__VA_ARGS__))[I0][LIBXSMM_VLA_ACCESS_NONCONST(I1)]) +# define LIBXSMM_VLA_ACCESS_3D(ARRAY, XY, I0, I1, I2, S1, ...) (((ARRAY) XY(S1) XY(__VA_ARGS__))[I0][I1][LIBXSMM_VLA_ACCESS_NONCONST(I2)]) +# define LIBXSMM_VLA_ACCESS_4D(ARRAY, XY, I0, I1, I2, I3, S1, S2, ...) (((ARRAY) XY(S1) XY(S2) XY(__VA_ARGS__))[I0][I1][I2][LIBXSMM_VLA_ACCESS_NONCONST(I3)]) +# define LIBXSMM_VLA_ACCESS_5D(ARRAY, XY, I0, I1, I2, I3, I4, S1, S2, S3, ...) (((ARRAY) XY(S1) XY(S2) XY(S3) XY(__VA_ARGS__))[I0][I1][I2][I3][LIBXSMM_VLA_ACCESS_NONCONST(I4)]) +# define LIBXSMM_VLA_ACCESS_6D(ARRAY, XY, I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, ...) (((ARRAY) XY(S1) XY(S2) XY(S3) XY(S4) XY(__VA_ARGS__))[I0][I1][I2][I3][I4][LIBXSMM_VLA_ACCESS_NONCONST(I5)]) +# define LIBXSMM_VLA_ACCESS_7D(ARRAY, XY, I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, ...) (((ARRAY) XY(S1) XY(S2) XY(S3) XY(S4) XY(S5) XY(__VA_ARGS__))[I0][I1][I2][I3][I4][I5][LIBXSMM_VLA_ACCESS_NONCONST(I6)]) +# define LIBXSMM_VLA_ACCESS_8D(ARRAY, XY, I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, ...) (((ARRAY) XY(S1) XY(S2) XY(S3) XY(S4) XY(S5) XY(S6) XY(__VA_ARGS__))[I0][I1][I2][I3][I4][I5][I6][LIBXSMM_VLA_ACCESS_NONCONST(I7)]) +# define LIBXSMM_VLA_ACCESS_9D(ARRAY, XY, I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, ...) (((ARRAY) XY(S1) XY(S2) XY(S3) XY(S4) XY(S5) XY(S6) XY(S7) XY(__VA_ARGS__))[I0][I1][I2][I3][I4][I5][I6][I7][LIBXSMM_VLA_ACCESS_NONCONST(I8)]) +# define LIBXSMM_VLA_ACCESS_10D(ARRAY, XY, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, S1, S2, S3, S4, S5, S6, S7, S8, ...) (((ARRAY) XY(S1) XY(S2) XY(S3) XY(S4) XY(S5) XY(S6) XY(S7) XY(S8) XY(__VA_ARGS__))[I0][I1][I2][I3][I4][I5][I6][I7][I8][LIBXSMM_VLA_ACCESS_NONCONST(I9)]) +# define LIBXSMM_VLA_DECL(NDIMS, ELEMENT_TYPE, ARRAY_VAR, .../*initial value, and bounds*/) \ + ELEMENT_TYPE LIBXSMM_VLA_ACCESS_ND(LIBXSMM_SELECT_ELEMENT(NDIMS, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9), *LIBXSMM_RESTRICT LIBXSMM_CONCATENATE(ARRAY_VAR, LIBXSMM_VLA_POSTFIX), \ + LIBXSMM_ELIDE, LIBXSMM_SELECT_TAIL(__VA_ARGS__, 0)/*bounds*/, LIBXSMM_SELECT_TAIL(__VA_ARGS__, 0)/*dummy*/) = \ + (ELEMENT_TYPE LIBXSMM_VLA_ACCESS_ND(LIBXSMM_SELECT_ELEMENT(NDIMS, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9), *, \ + LIBXSMM_ELIDE, LIBXSMM_SELECT_TAIL(__VA_ARGS__, 0)/*bounds*/, LIBXSMM_SELECT_TAIL(__VA_ARGS__, 0)/*dummy*/))LIBXSMM_SELECT_HEAD(__VA_ARGS__) +#else /* calculate linear index */ +# define LIBXSMM_VLA_ACCESS(NDIMS, ARRAY, ...) LIBXSMM_CONCATENATE(ARRAY, LIBXSMM_VLA_POSTFIX)[LIBXSMM_INDEX1(NDIMS, __VA_ARGS__)] +# define LIBXSMM_VLA_DECL(NDIMS, ELEMENT_TYPE, ARRAY_VAR, .../*initial value, and bounds*/) \ + ELEMENT_TYPE *LIBXSMM_RESTRICT LIBXSMM_CONCATENATE(ARRAY_VAR, LIBXSMM_VLA_POSTFIX) = /*(ELEMENT_TYPE*)*/LIBXSMM_SELECT_HEAD(__VA_ARGS__) \ + + 0 * LIBXSMM_INDEX1(NDIMS, LIBXSMM_SELECT_TAIL(__VA_ARGS__, LIBXSMM_SELECT_TAIL(__VA_ARGS__, 0))) /* dummy-shift to "sink" unused arguments */ +#endif + +/** Access an array of TYPE with Byte-measured stride. */ +#define LIBXSMM_ACCESS(TYPE, ARRAY, STRIDE) (*(TYPE*)((char*)(ARRAY) + (STRIDE))) + +#if !defined(LIBXSMM_UNUSED) +# if 0 +# define LIBXSMM_UNUSED(VARIABLE) LIBXSMM_PRAGMA(unused(VARIABLE)) +# else +# define LIBXSMM_UNUSED(VARIABLE) (void)(VARIABLE) +# endif +#endif +#if !defined(NDEBUG) +# define LIBXSMM_UNUSED_DEBUG(VARIABLE) LIBXSMM_UNUSED(VARIABLE) +#else +# define LIBXSMM_UNUSED_DEBUG(VARIABLE) +#endif + +#if defined(_OPENMP) +# define LIBXSMM_PRAGMA_OMP(...) LIBXSMM_PRAGMA(omp __VA_ARGS__) +# if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +# define LIBXSMM_OMP_VAR(A) LIBXSMM_UNUSED(A) /* suppress warning about "unused" variable */ +# elif defined(__clang__) +# define LIBXSMM_OMP_VAR(A) (A) = 0 +# else +# define LIBXSMM_OMP_VAR(A) +# endif +#else +# define LIBXSMM_PRAGMA_OMP(...) +# define LIBXSMM_OMP_VAR(A) +#endif + +#if defined(LIBXSMM_BUILD) && (defined(__GNUC__) || defined(__clang__)) && !defined(__CYGWIN__) && !defined(__MINGW32__) +# define LIBXSMM_ATTRIBUTE_WEAK_IMPORT LIBXSMM_ATTRIBUTE(weak_import) +# define LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE(weak) +#else +# define LIBXSMM_ATTRIBUTE_WEAK +# define LIBXSMM_ATTRIBUTE_WEAK_IMPORT +#endif + +#if !defined(LIBXSMM_NO_CTOR) && !defined(LIBXSMM_CTOR) && \ + (defined(__STDC_VERSION__) && (199901L <= __STDC_VERSION__)) && \ + (defined(LIBXSMM_BUILD) && !defined(__STATIC)) && \ + (defined(__GNUC__) || defined(__clang__)) +# define LIBXSMM_ATTRIBUTE_CTOR LIBXSMM_ATTRIBUTE(constructor) +# define LIBXSMM_ATTRIBUTE_DTOR LIBXSMM_ATTRIBUTE(destructor) +# define LIBXSMM_CTOR +#else +# define LIBXSMM_ATTRIBUTE_CTOR +# define LIBXSMM_ATTRIBUTE_DTOR +#endif + +#if defined(__GNUC__) && !defined(__PGI) && !defined(__ibmxl__) +# define LIBXSMM_ATTRIBUTE_NO_TRACE LIBXSMM_ATTRIBUTE(no_instrument_function) +#else +# define LIBXSMM_ATTRIBUTE_NO_TRACE +#endif + +#if defined(__GNUC__) +# define LIBXSMM_MAY_ALIAS LIBXSMM_ATTRIBUTE(__may_alias__) +#else +# define LIBXSMM_MAY_ALIAS +#endif + +#if !defined(LIBXSMM_MKTEMP_PATTERN) +# define LIBXSMM_MKTEMP_PATTERN "XXXXXX" +#endif + +/** Below group is to fix-up some platform/compiler specifics. */ +#if defined(_WIN32) +# if !defined(_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES) +# define _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES 1 +# endif +# if !defined(_CRT_SECURE_NO_DEPRECATE) +# define _CRT_SECURE_NO_DEPRECATE 1 +# endif +# if !defined(_USE_MATH_DEFINES) +# define _USE_MATH_DEFINES 1 +# endif +# if !defined(WIN32_LEAN_AND_MEAN) +# define WIN32_LEAN_AND_MEAN 1 +# endif +# if !defined(NOMINMAX) +# define NOMINMAX 1 +# endif +# if defined(__INTEL_COMPILER) && (190023506 <= _MSC_FULL_VER) +# define __builtin_huge_val() HUGE_VAL +# define __builtin_huge_valf() HUGE_VALF +# define __builtin_nan nan +# define __builtin_nanf nanf +# define __builtin_nans nan +# define __builtin_nansf nanf +# if defined(__cplusplus) +# define _CMATH_ +# endif +# endif +#endif +#if !defined(_GNU_SOURCE) && defined(LIBXSMM_BUILD) +# define _GNU_SOURCE +#endif +#if !defined(__STDC_FORMAT_MACROS) +# define __STDC_FORMAT_MACROS +#endif +#if defined(__clang__) && !defined(__extern_always_inline) +# define __extern_always_inline LIBXSMM_INLINE +#endif +#if defined(LIBXSMM_INLINE_FIXUP) && !defined(inline) +# define inline LIBXSMM_INLINE_KEYWORD +#endif + +#if (0 != LIBXSMM_SYNC) +# if !defined(_REENTRANT) +# define _REENTRANT +# endif +# if defined(__PGI) +# if defined(__GCC_ATOMIC_TEST_AND_SET_TRUEVAL) +# undef __GCC_ATOMIC_TEST_AND_SET_TRUEVAL +# endif +# define __GCC_ATOMIC_TEST_AND_SET_TRUEVAL 1 +# endif +#endif + +#if !defined(__has_feature) && !defined(__clang__) +# define __has_feature(A) 0 +#endif +#if !defined(__has_builtin) && !defined(__clang__) +# define __has_builtin(A) 0 +#endif + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif + +#if (0 != LIBXSMM_SYNC) +# if defined(_WIN32) || defined(__CYGWIN__) +# include +# else +# include +# endif +#endif +#if !defined(LIBXSMM_ASSERT) +# include +# if defined(NDEBUG) +# define LIBXSMM_ASSERT(EXPR) LIBXSMM_ASSUME(EXPR) +# else +# define LIBXSMM_ASSERT(EXPR) assert(EXPR) +# endif +#endif +#if !defined(LIBXSMM_ASSERT_MSG) +# define LIBXSMM_ASSERT_MSG(EXPR, MSG) assert((EXPR) && *MSG) +#endif +#if !defined(LIBXSMM_EXPECT_ELIDE) +# define LIBXSMM_EXPECT_ELIDE(RESULT, EXPR) do { \ + /*const*/ int libxsmm_expect_result_ = ((RESULT) == (EXPR)); \ + LIBXSMM_UNUSED(libxsmm_expect_result_); \ + } while(0) +#endif +#if defined(NDEBUG) +# define LIBXSMM_EXPECT LIBXSMM_EXPECT_ELIDE +# define LIBXSMM_EXPECT_NOT LIBXSMM_EXPECT_ELIDE +#else +# define LIBXSMM_EXPECT(RESULT, EXPR) LIBXSMM_ASSERT((RESULT) == (EXPR)) +# define LIBXSMM_EXPECT_NOT(RESULT, EXPR) LIBXSMM_ASSERT((RESULT) != (EXPR)) +#endif +#if defined(_DEBUG) +# define LIBXSMM_EXPECT_DEBUG LIBXSMM_EXPECT +#else +# define LIBXSMM_EXPECT_DEBUG LIBXSMM_EXPECT_ELIDE +#endif +#if defined(_OPENMP) && defined(LIBXSMM_SYNC_OMP) +# include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(FLT_MAX) +# if !defined(__FLT_MAX__) +# define FLT_MAX 3.40282346638528859811704183484516925e+38F +# else +# define FLT_MAX __FLT_MAX__ +# endif +#endif +#if !defined(FLT_MIN) +# if !defined(__FLT_MIN__) +# define FLT_MIN 1.17549435082228750796873653722224568e-38F +# else +# define FLT_MIN __FLT_MIN__ +# endif +#endif +#if defined(_WIN32) && 0 +# define LIBXSMM_SNPRINTF(S, N, ...) _snprintf_s(S, N, _TRUNCATE, __VA_ARGS__) +#elif defined(__STDC_VERSION__) && (199901L <= __STDC_VERSION__ || defined(__GNUC__)) +# define LIBXSMM_SNPRINTF(S, N, ...) snprintf(S, N, __VA_ARGS__) +#else +# define LIBXSMM_SNPRINTF(S, N, ...) sprintf((S) + /*unused*/(N) * 0, __VA_ARGS__) +#endif + +#if defined(__THROW) && defined(__cplusplus) +# define LIBXSMM_THROW __THROW +#endif +#if !defined(LIBXSMM_THROW) +# define LIBXSMM_THROW +#endif +#if defined(__GNUC__) && LIBXSMM_VERSION2(4, 2) == LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) && \ + !defined(__clang__) && !defined(__PGI) && !defined(__INTEL_COMPILER) && !defined(_CRAYC) +# define LIBXSMM_NOTHROW LIBXSMM_THROW +#else +# define LIBXSMM_NOTHROW +#endif +#if defined(__cplusplus) +# if (__cplusplus > 199711L) +# define LIBXSMM_NOEXCEPT noexcept +# else +# define LIBXSMM_NOEXCEPT throw() +# endif +#else +# define LIBXSMM_NOEXCEPT LIBXSMM_NOTHROW +#endif + +#if defined(_WIN32) +# define LIBXSMM_PUTENV(A) _putenv(A) +#else +# define LIBXSMM_PUTENV(A) putenv(A) +#endif + +/* block must be after including above header files */ +#if (defined(__GLIBC__) && defined(__GLIBC_MINOR__) && LIBXSMM_VERSION2(__GLIBC__, __GLIBC_MINOR__) < LIBXSMM_VERSION2(2, 26)) \ + || (defined(LIBXSMM_INTEL_COMPILER) && (1802 >= LIBXSMM_INTEL_COMPILER) && !defined(__cplusplus) && defined(__linux__)) +/* _Float128 was introduced with GNU GCC 7.0. */ +# if !defined(_Float128) && !defined(__SIZEOF_FLOAT128__) && defined(__GNUC__) && !defined(__cplusplus) && defined(__linux__) +# define _Float128 __float128 +# endif +# if !defined(LIBXSMM_GLIBC_FPTYPES) && defined(__GNUC__) && !defined(__cplusplus) && defined(__linux__) \ + && (LIBXSMM_VERSION2(7, 0) > LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) || \ + (defined(LIBXSMM_INTEL_COMPILER) && (1802 >= LIBXSMM_INTEL_COMPILER))) +# define LIBXSMM_GLIBC_FPTYPES +# endif +# if !defined(_Float128X) && defined(LIBXSMM_GLIBC_FPTYPES) +# define _Float128X _Float128 +# endif +# if !defined(_Float32) && defined(LIBXSMM_GLIBC_FPTYPES) +# define _Float32 float +# endif +# if !defined(_Float32x) && defined(LIBXSMM_GLIBC_FPTYPES) +# define _Float32x _Float32 +# endif +# if !defined(_Float64) && defined(LIBXSMM_GLIBC_FPTYPES) +# define _Float64 double +# endif +# if !defined(_Float64x) && defined(LIBXSMM_GLIBC_FPTYPES) +# define _Float64x _Float64 +# endif +#endif + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if defined(LIBXSMM_GLIBC_FPTYPES) +# if defined(__cplusplus) +# undef __USE_MISC +# if !defined(_DEFAULT_SOURCE) +# define _DEFAULT_SOURCE +# endif +# if !defined(_BSD_SOURCE) +# define _BSD_SOURCE +# endif +# else +# if !defined(__PURE_INTEL_C99_HEADERS__) +# define __PURE_INTEL_C99_HEADERS__ +# endif +# endif +#endif +#if !defined(LIBXSMM_NO_LIBM) +# if (defined(LIBXSMM_INTEL_COMPILER) && (1800 <= LIBXSMM_INTEL_COMPILER)) \ + && !defined(_WIN32) /* error including dfp754.h */ +# include +# endif +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#endif /*LIBXSMM_MACROS_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_malloc.h b/third_party/libxsmm/include/libxsmm_malloc.h new file mode 100644 index 0000000000000000000000000000000000000000..3f978fea287af2f96ccd9bafa21501d66114ac6a --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_malloc.h @@ -0,0 +1,397 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MALLOC_H +#define LIBXSMM_MALLOC_H + +#include "libxsmm_memory.h" + +/* include tensorflow/core/public/version.h prior to LIBXSMM otherwise the current TensorFlow API is assumed */ +#if !defined(LIBXSMM_TF12) && (!defined(TF_VERSION_STRING) || \ + LIBXSMM_VERSION2(1, 12) <= LIBXSMM_VERSION2(TF_MAJOR_VERSION, TF_MINOR_VERSION)) +# define LIBXSMM_TF12 /* TF_PATCH_VERSION does not matter */ +#endif + +/** Can be used with libxsmm_[get|set]_scratch_limit. */ +#define LIBXSMM_SCRATCH_UNLIMITED ((size_t)LIBXSMM_UNLIMITED) +#define LIBXSMM_SCRATCH_DEFAULT 0 + + +/** Function types accepted for memory allocation (see libxsmm_*_allocator). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void* (*libxsmm_malloc_ctx)(size_t /*size*/, const void* /*context*/); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void* (*libxsmm_malloc_fun)(size_t /*size*/); +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_malloc_function { + libxsmm_malloc_ctx ctx_form; + libxsmm_malloc_fun function; +} libxsmm_malloc_function; + +/** Function types accepted for releasing memory (see libxsmm_*_allocator). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_free_ctx)(void* /*buffer*/, const void* /*context*/); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_free_fun)(void* /*buffer*/); +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_free_function { + libxsmm_free_ctx ctx_form; + libxsmm_free_fun function; +} libxsmm_free_function; + +/** + * To setup the custom default memory allocator, either a malloc_fn and a free_fn + * are given, or two NULL-pointers designate to reset the default allocator to a + * library-internal default. If a context is given (non-NULL), the context-based + * form of the memory allocation is used. + * Changing the allocator including the function for deallocation applies to + * upcoming allocation/deallocation and works correctly for pending buffers. + */ +LIBXSMM_API int libxsmm_set_default_allocator(/* malloc_fn/free_fn must correspond */ + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn); +/** Retrieve the default memory allocator. */ +LIBXSMM_API int libxsmm_get_default_allocator(const void** context, + libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn); + +/** + * To setup the scratch memory allocator, a malloc_fn function and an optional free_fn + * are given. A NULL-free acts as a "no-operation", and the deallocation is expected + * to be controlled otherwise. If two NULL-pointers are given, the allocator is reset + * to the currently active default memory allocator. If a context is given (non-NULL), + * the context-based form of the memory allocation is used. + * Changing the allocator including the function for deallocation applies to + * upcoming allocation/deallocation and works correctly for pending buffers. + */ +LIBXSMM_API int libxsmm_set_scratch_allocator(/* malloc_fn/free_fn must correspond */ + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn); +/** Retrieve the scratch memory allocator. */ +LIBXSMM_API int libxsmm_get_scratch_allocator(const void** context, + libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn); + +/** Allocate memory (malloc/free interface). */ +LIBXSMM_API LIBXSMM_ATTRIBUTE_MALLOC void* libxsmm_malloc(size_t size); + +/** Allocate aligned memory using the default allocator. */ +LIBXSMM_API LIBXSMM_ATTRIBUTE_MALLOC void* libxsmm_aligned_malloc(size_t size, + /** + * =0: align automatically according to the size + * 0<: align according to the alignment value + */ + size_t alignment); + +/** Reallocate memory using the default allocator (alignment is preserved). */ +LIBXSMM_API void* libxsmm_realloc(size_t size, void* ptr); + +/** + * Allocate aligned scratch memory. It is not supported + * to query properties per libxsmm_get_malloc_info, but + * libxsmm_get_scratch_info can used instead. + */ +LIBXSMM_API void* libxsmm_scratch_malloc(size_t size, + /** + * =0: align automatically according to the size + * 0<: align according to the alignment value + */ + size_t alignment, + /** + * Identifies the call site, which is used + * to determine the memory pool. + */ + const void* caller); + +/** + * Binary form of libxsmm_scratch_malloc, which + * expands the call-context automatically. This + * macro is intentionally lower case. + */ +#define libxsmm_aligned_scratch(size, alignment) \ + libxsmm_scratch_malloc(size, alignment, \ + LIBXSMM_CALLER_ID) + +/** Deallocate memory (malloc/free interface). */ +LIBXSMM_API void libxsmm_free(const void* memory); + +/** + * Release the entire scratch memory regardless + * of whether it is still referenced or not. + */ +LIBXSMM_API void libxsmm_release_scratch(void); + +/** Information about a buffer (default memory domain). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_malloc_info { + /** Size of the buffer. */ + size_t size; +} libxsmm_malloc_info; + +/** Retrieve information about a buffer (default memory domain). */ +LIBXSMM_API int libxsmm_get_malloc_info(const void* memory, libxsmm_malloc_info* info); + +/** Information about the scratch memory domain. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_scratch_info { + /** Watermark memory across pools (size), unsatisfied (local), and library-internal memory. */ + size_t size, local, internal; + /** Pending allocations (not released). */ + size_t npending; + /** Number of allocations so far. */ + size_t nmallocs; + /** Number of pools used. */ + unsigned int npools; +} libxsmm_scratch_info; + +/** Retrieve information about the scratch memory domain. */ +LIBXSMM_API int libxsmm_get_scratch_info(libxsmm_scratch_info* info); + +/** + * Limit the total size (Bytes) of the scratch memory. + * LIBXSMM_SCRATCH_UNLIMITED removes any limit, and + * LIBXSMM_SCRATCH_DEFAULT populates the default. + * The related environment variable LIBXSMM_SCRATCH_LIMIT + * allows units: /b/B (Bytes), k/K, m/M, and g/G. + */ +LIBXSMM_API void libxsmm_set_scratch_limit(size_t nbytes); +/** Get the maximum size of the scratch memory domain. */ +LIBXSMM_API size_t libxsmm_get_scratch_limit(void); + +/** + * Intercepts malloc/free to use scratch memory allocator. + * (related environment variable LIBXSMM_MALLOC). + * Optionally set the range of malloc-sizes to be intercepted. + * The related environment variable LIBXSMM_MALLOC_LIMIT + * allows units: /b/B (Bytes), k/K, m/M, and g/G. + */ +LIBXSMM_API void libxsmm_set_malloc(int enabled, const size_t* lo, const size_t* hi); +/** + * Determines if malloc/free are (and can be) intercepted. + * Optionally gets the range of enabled malloc-sizes. + */ +LIBXSMM_API int libxsmm_get_malloc(size_t* lo, size_t* hi); + +/** + * Calculate the linear offset of the n-dimensional (ndims) offset (can be NULL), + * and the (optional) linear size of the corresponding shape. + */ +LIBXSMM_API size_t libxsmm_offset(const size_t offset[], const size_t shape[], size_t ndims, size_t* size); + + +#if defined(__cplusplus) + +/** RAII idiom to temporarily setup an allocator for the lifetime of the scope. */ +template class LIBXSMM_RETARGETABLE libxsmm_scoped_allocator { +public: + /** C'tor, which instantiates the new allocator (plain form). */ + libxsmm_scoped_allocator(libxsmm_malloc_fun malloc_fn, libxsmm_free_fun free_fn) { + kind::get(m_context, m_malloc, m_free); + kind::set(NULL/*context*/, NULL/*malloc_ctx*/, NULL/*free_ctx*/, malloc_fn, free_fn); + } + + /** C'tor, which instantiates the new allocator (context form). */ + libxsmm_scoped_allocator(const void* context, libxsmm_malloc_ctx malloc_ctx, libxsmm_free_ctx free_ctx, + libxsmm_malloc_fun malloc_fun = NULL, libxsmm_free_fun free_fun = NULL) + { + kind::get(m_context, m_malloc, m_free); + kind::set(context, malloc_ctx, free_ctx, malloc_fun, free_fun); + } + + /** Following the RAII idiom, the d'tor restores the previous allocator. */ + ~libxsmm_scoped_allocator() { + kind::set(m_context, + m_malloc.ctx_form, m_free.ctx_form, + m_malloc.function, m_free.function); + } + +private: /* no copy/assignment */ + explicit libxsmm_scoped_allocator(const libxsmm_scoped_allocator&); + libxsmm_scoped_allocator& operator=(const libxsmm_scoped_allocator&); + +protected: /* saved/previous allocator */ + const void* m_context; + libxsmm_malloc_function m_malloc; + libxsmm_free_function m_free; +}; + +/** Allocator-kind to instantiate libxsmm_scoped_allocator. */ +struct LIBXSMM_RETARGETABLE libxsmm_default_allocator { + static void set(const void* context, + libxsmm_malloc_ctx malloc_ctx, libxsmm_free_ctx free_ctx, + libxsmm_malloc_fun malloc_fun, libxsmm_free_fun free_fun) + { + libxsmm_malloc_function malloc_fn; + libxsmm_free_function free_fn; + if (NULL == context) { /* use global form only when no context is given */ + malloc_fn.function = malloc_fun; free_fn.function = free_fun; + } + else { + malloc_fn.ctx_form = malloc_ctx; free_fn.ctx_form = free_ctx; + } + libxsmm_set_default_allocator(context, malloc_fn, free_fn); + } + static void get(const void*& context, + libxsmm_malloc_function& malloc_fn, libxsmm_free_function& free_fn) + { + libxsmm_get_default_allocator(&context, &malloc_fn, &free_fn); + } +}; + +/** Allocator-kind to instantiate libxsmm_scoped_allocator. */ +struct LIBXSMM_RETARGETABLE libxsmm_scratch_allocator { + static void set(const void* context, + libxsmm_malloc_ctx malloc_ctx, libxsmm_free_ctx free_ctx, + libxsmm_malloc_fun malloc_fun, libxsmm_free_fun free_fun) + { + libxsmm_malloc_function malloc_fn; + libxsmm_free_function free_fn; + if (NULL != context) { /* adopt context form */ + malloc_fn.function = malloc_fun; free_fn.function = free_fun; + } + else { /* adopt global form */ + malloc_fn.ctx_form = malloc_ctx; free_fn.ctx_form = free_ctx; + } + libxsmm_set_scratch_allocator(context, malloc_fn, free_fn); + } + static void get(const void*& context, + libxsmm_malloc_function& malloc_fn, libxsmm_free_function& free_fn) + { + libxsmm_get_scratch_allocator(&context, &malloc_fn, &free_fn); + } +}; + +/** Forward-declared types/functions used to implement libxsmm_tf_allocator. */ +namespace tensorflow { + class Allocator; +#if defined(LIBXSMM_TF12) + class DeviceBase; int DeviceNumaNode(const DeviceBase* /*device*/); + Allocator* cpu_allocator(int /*numa_node*/); +#else + Allocator* cpu_allocator(); +#endif +} + +/** + * An object of this type adopts a memory allocator from TensorFlow. + * All memory allocations of the requested kind within the current + * scope (where the libxsmm_tf_allocator object lives) are subject + * to TensorFlow's memory allocation scheme. The allocation kind + * is usually "libxsmm_scratch_allocator"; using a second object + * of kind "libxsmm_default_allocator" makes the default memory + * allocation of LIBXSMM subject to TensorFlow as well. + */ +template class LIBXSMM_RETARGETABLE libxsmm_tf_allocator: + public libxsmm_scoped_allocator +{ +public: + /** The TensorFlow allocator is adopted from the global CPU memory allocator. */ + explicit libxsmm_tf_allocator() + : libxsmm_scoped_allocator( + libxsmm_tf_allocator::malloc, + libxsmm_tf_allocator::free) + {} + + /** The TensorFlow allocator is adopted from the given OpKernelContext. */ + template + explicit libxsmm_tf_allocator(context_type& context) + : libxsmm_scoped_allocator(&context, + libxsmm_tf_allocator::template malloc_ctx, + libxsmm_tf_allocator::template free_ctx, + libxsmm_tf_allocator::malloc, + libxsmm_tf_allocator::free) + {} + + /** Global form of allocating memory (malloc signature). */ + static void* malloc(size_t size) { +#if defined(LIBXSMM_TF12) + return libxsmm_tf_allocator::allocate(tensorflow::cpu_allocator(-1/*kNUMANoAffinity*/), size); +#else + return libxsmm_tf_allocator::allocate(tensorflow::cpu_allocator(), size); +#endif + } + + /** Global form of deallocating memory (free signature). */ + static void free(void* buffer) { +#if defined(LIBXSMM_TF12) + libxsmm_tf_allocator::deallocate(tensorflow::cpu_allocator(-1/*kNUMANoAffinity*/), buffer); +#else + libxsmm_tf_allocator::deallocate(tensorflow::cpu_allocator(), buffer); +#endif + } + + /** Context based form of allocating memory. */ + template static void* malloc_ctx(const void* context, size_t size) { + typedef typename context_type::WrappedAllocator::first_type allocator_ptr; + context_type *const tf_context = static_cast(context); + allocator_ptr allocator = NULL; + if (NULL != tf_context) { +#if !defined(LIBXSMM_TF12) + if (NULL != tf_context->device()) { + if (0 < tf_context->num_outputs()) { + allocator = tf_context->device()->GetStepAllocator( + tf_context->output_alloc_attr(0), + tf_context->resource_manager()); + } + else if (0 < tf_context->num_inputs()) { + allocator = tf_context->device()->GetStepAllocator( + tf_context->input_alloc_attr(0), + tf_context->resource_manager()); + } + } +#else /* include tensorflow/core/public/version.h prior to LIBXSMM otherwise the current TensorFlow API is assumed */ + const int numa_node = DeviceNumaNode(tf_context->device()); + allocator = tensorflow::cpu_allocator(numa_node); +#endif + } + return libxsmm_tf_allocator::allocate(allocator, size); + } + + /** Context based form of deallocating memory. */ + template static void free_ctx(const void* context, void* buffer) { + typedef typename context_type::WrappedAllocator::first_type allocator_ptr; + context_type *const tf_context = static_cast(context); + allocator_ptr allocator = NULL; + if (NULL != tf_context) { +#if defined(LIBXSMM_TF12) + const int numa_node = DeviceNumaNode(tf_context->device()); + allocator = tensorflow::cpu_allocator(numa_node); +#else + if (NULL != tf_context->device()) { + if (0 < tf_context->num_outputs()) { + allocator = tf_context->device()->GetStepAllocator( + tf_context->output_alloc_attr(0), + tf_context->resource_manager()); + } + else if (0 < tf_context->num_inputs()) { + allocator = tf_context->device()->GetStepAllocator( + tf_context->input_alloc_attr(0), + tf_context->resource_manager()); + } + } +#endif + } + libxsmm_tf_allocator::deallocate(allocator, buffer); + } + +private: + template /* break interface dependency with TF */ + static void* allocate(allocator_ptr allocator, size_t size) { + void* result; + if (NULL != allocator) { + /* no (useless) waste with alignment; raw result is re-aligned anyways */ + result = allocator->AllocateRaw(1/*alignment*/, size); + } + else { + LIBXSMM_ASSERT_MSG(0/*false*/, "LIBXSMM ERROR: memory allocator is missing"); + result = NULL; + } + return result; + } + + template /* break interface dependency with TF */ + static void deallocate(allocator_ptr allocator, void* buffer) { + LIBXSMM_ASSERT_MSG(NULL != allocator, "LIBXSMM ERROR: memory allocator is missing"); + if (NULL != allocator) allocator->DeallocateRaw(buffer); + } +}; + +#endif /*defined(__cplusplus)*/ + +#endif /*LIBXSMM_MALLOC_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_math.h b/third_party/libxsmm/include/libxsmm_math.h new file mode 100644 index 0000000000000000000000000000000000000000..f6514228ee3f752d2d6ade0e47dc352a25492026 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_math.h @@ -0,0 +1,140 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MATH_H +#define LIBXSMM_MATH_H + +#include "libxsmm_typedefs.h" + + +/** + * Structure of differences with matrix norms according + * to http://www.netlib.org/lapack/lug/node75.html). + */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_matdiff_info { + /** One-norm */ double norm1_abs, norm1_rel; + /** Infinity-norm */ double normi_abs, normi_rel; + /** Froebenius-norm */ double normf_rel; + /** Maximum difference, L2-norm (absolute and relative), and R-squared. */ + double linf_abs, linf_rel, l2_abs, l2_rel, rsq; + /** Statistics: sum/l1, min., max., arith. avg., and variance. */ + double l1_ref, min_ref, max_ref, avg_ref, var_ref; + /** Statistics: sum/l1, min., max., arith. avg., and variance. */ + double l1_tst, min_tst, max_tst, avg_tst, var_tst; + /** Values (v_ref, v_tst) and location (m, n) of largest linf_abs. */ + double v_ref, v_tst; + libxsmm_blasint m, n; +} libxsmm_matdiff_info; + +/** + * Utility function to calculate a collection of scalar differences between two matrices (libxsmm_matdiff_info). + * The location (m, n) of the largest difference (linf_abs) is recorded (also in case of NaN). In case of NaN, + * differences are set to infinity. If no difference is discovered, the location (m, n) is negative (OOB). + */ +LIBXSMM_API int libxsmm_matdiff(libxsmm_matdiff_info* info, + libxsmm_datatype datatype, libxsmm_blasint m, libxsmm_blasint n, const void* ref, const void* tst, + const libxsmm_blasint* ldref, const libxsmm_blasint* ldtst); + +/** + * Reduces input into output such that the difference is maintained or increased (max function). + * The very first (initial) output should be zeroed (libxsmm_matdiff_clear). + */ +LIBXSMM_API void libxsmm_matdiff_reduce(libxsmm_matdiff_info* output, const libxsmm_matdiff_info* input); +/** Clears the given info-structure, e.g., for the initial reduction-value (libxsmm_matdiff_reduce). */ +LIBXSMM_API void libxsmm_matdiff_clear(libxsmm_matdiff_info* info); + +/** Greatest common divisor (corner case: the GCD of 0 and 0 is 1). */ +LIBXSMM_API size_t libxsmm_gcd(size_t a, size_t b); +/** Least common multiple. */ +LIBXSMM_API size_t libxsmm_lcm(size_t a, size_t b); + +/** + * This function finds prime-factors (up to 32) of an unsigned integer in ascending order, and + * returns the number of factors found (zero if the given number is prime and unequal to two). + */ +LIBXSMM_API int libxsmm_primes_u32(unsigned int num, unsigned int num_factors_n32[]); + +/** Calculate co-prime number <= n/2 (except: libxsmm_shuffle(0|1) == 0). */ +LIBXSMM_API size_t libxsmm_shuffle(unsigned int n); + +/** + * Divides the product into prime factors and selects factors such that the new product is within + * the given limit (0/1-Knapsack problem), e.g., product=12=2*2*3 and limit=6 then result=2*3=6. + * The limit is at least reached or exceeded with the minimal possible product (is_lower=true). + */ +LIBXSMM_API unsigned int libxsmm_product_limit(unsigned int product, unsigned int limit, int is_lower); + +/* Kahan's summation returns accumulator += value and updates compensation. */ +LIBXSMM_API double libxsmm_kahan_sum(double value, double* accumulator, double* compensation); + +/** SQRT with Newton's method using integer arithmetic. */ +LIBXSMM_API unsigned int libxsmm_isqrt_u64(unsigned long long x); +/** SQRT with Newton's method using integer arithmetic. */ +LIBXSMM_API unsigned int libxsmm_isqrt_u32(unsigned int x); +/** Based on libxsmm_isqrt_u32, but actual factor of x. */ +LIBXSMM_API unsigned int libxsmm_isqrt2_u32(unsigned int x); +/** SQRT with Newton's method using double-precision. */ +LIBXSMM_API double libxsmm_dsqrt(double x); +/** SQRT with Newton's method using single-precision. */ +LIBXSMM_API float libxsmm_ssqrt(float x); + +/** CBRT with Newton's method using integer arithmetic. */ +LIBXSMM_API unsigned int libxsmm_icbrt_u64(unsigned long long x); +/** CBRT with Newton's method using integer arithmetic. */ +LIBXSMM_API unsigned int libxsmm_icbrt_u32(unsigned int x); + +/** Single-precision approximation of exponential function (base 2). */ +LIBXSMM_API float libxsmm_sexp2(float x); + +/** + * Exponential function (base 2), which is limited to unsigned 8-bit input values. + * This function reproduces bit-accurate results (single-precision). + */ +LIBXSMM_API float libxsmm_sexp2_u8(unsigned char x); + +/** +* Exponential function (base 2), which is limited to signed 8-bit input values. +* This function reproduces bit-accurate results (single-precision). +*/ +LIBXSMM_API float libxsmm_sexp2_i8(signed char x); + +/** Similar to libxsmm_sexp2_i8, but takes an integer as signed 8-bit value (check). */ +LIBXSMM_API float libxsmm_sexp2_i8i(int x); + +/** Inlineable fast tanh, such that a the compiler can potentially vectorize. */ +LIBXSMM_API_INLINE float libxsmm_stanh_pade78(float i_x) { + const float l_c0 = 2027025.0f; + const float l_c1 = 270270.0f; + const float l_c2 = 6930.0f; + const float l_c3 = 36.0f; + const float l_c1_d = 945945.0f; + const float l_c2_d = 51975.0f; + const float l_c3_d = 630.0f; + const float l_hi_bound = 4.97f; + const float l_lo_bound = -4.97f; + const float l_ones = 1.0f; + const float l_neg_ones = -1.0f; + const float x2 = i_x * i_x; + const float t1_nom = (l_c3 * x2) + l_c2; + const float t2_nom = (t1_nom * x2) + l_c1; + const float t3_nom = (t2_nom * x2) + l_c0; + const float nom = t3_nom * i_x; + const float t1_denom = x2 + l_c3_d; + const float t2_denom = (t1_denom * x2) + l_c2_d; + const float t3_denom = (t2_denom * x2) + l_c1_d; + const float denom = (t3_denom * x2) + l_c0; + float result = nom/denom ; + result = (result > l_hi_bound) ? l_ones : result; + result = (result < l_lo_bound) ? l_neg_ones : result; + return result; +} + +#endif /*LIBXSMM_MATH_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_memory.h b/third_party/libxsmm/include/libxsmm_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..53d4ed763077a9408cabfc117c8e19fd2177f649 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_memory.h @@ -0,0 +1,85 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MEMORY_H +#define LIBXSMM_MEMORY_H + +#include "libxsmm_macros.h" + +#if defined(__clang_analyzer__) +# define LIBXSMM_MEMSET127(PTRDST, VALUE, SIZE) memset((void*)(PTRDST), VALUE, SIZE) +#else +# define LIBXSMM_MEMSET127(PTRDST, VALUE, SIZE) { \ + char *const libxsmm_memset127_dst_ = (char*)(PTRDST); \ + union { size_t size; signed char size1; } libxsmm_memset127_; \ + signed char libxsmm_memset127_i_; LIBXSMM_ASSERT((SIZE) <= 127); \ + libxsmm_memset127_.size = (SIZE); \ + LIBXSMM_PRAGMA_UNROLL \ + for (libxsmm_memset127_i_ = 0; libxsmm_memset127_i_ < libxsmm_memset127_.size1; \ + ++libxsmm_memset127_i_) \ + { \ + libxsmm_memset127_dst_[libxsmm_memset127_i_] = (char)(VALUE); \ + } \ +} +#endif +#define LIBXSMM_MEMZERO127(PTRDST) LIBXSMM_MEMSET127(PTRDST, '\0', sizeof(*(PTRDST))) + +#define LIBXSMM_MEMCPY127_LOOP(PTRDST, PTRSRC, SIZE, NTS) { \ + const unsigned char *const libxsmm_memcpy127_loop_src_ = (const unsigned char*)(PTRSRC); \ + unsigned char *const libxsmm_memcpy127_loop_dst_ = (unsigned char*)(PTRDST); \ + signed char libxsmm_memcpy127_loop_i_; LIBXSMM_ASSERT((SIZE) <= 127); \ + NTS(libxsmm_memcpy127_loop_dst_) LIBXSMM_PRAGMA_UNROLL \ + for (libxsmm_memcpy127_loop_i_ = 0; libxsmm_memcpy127_loop_i_ < (signed char)(SIZE); \ + ++libxsmm_memcpy127_loop_i_) \ + { \ + libxsmm_memcpy127_loop_dst_[libxsmm_memcpy127_loop_i_] = \ + libxsmm_memcpy127_loop_src_[libxsmm_memcpy127_loop_i_]; \ + } \ +} +#define LIBXSMM_MEMCPY127_NTS(...) +#define LIBXSMM_MEMCPY127(PTRDST, PTRSRC, SIZE) \ + LIBXSMM_MEMCPY127_LOOP(PTRDST, PTRSRC, SIZE, LIBXSMM_MEMCPY127_NTS) +#define LIBXSMM_ASSIGN127(PTRDST, PTRSRC) LIBXSMM_ASSERT(sizeof(*(PTRSRC)) <= sizeof(*(PTRDST))); \ + LIBXSMM_MEMCPY127(PTRDST, PTRSRC, sizeof(*(PTRSRC))) + + +/** + * Calculates if there is a difference between two (short) buffers. + * Returns zero if there is no difference; otherwise non-zero. + */ +LIBXSMM_API unsigned char libxsmm_diff(const void* a, const void* b, unsigned char size); + +/** + * Calculates if there is a difference between "a" and "n x b". + * Returns the index of the first match (or "n" in case of no match). + */ +LIBXSMM_API unsigned int libxsmm_diff_n(const void* a, const void* bn, unsigned char size, + unsigned char stride, unsigned int hint, unsigned int n); + +/** Similar to memcmp (C standard library), but the result is conceptually only a boolean. */ +LIBXSMM_API int libxsmm_memcmp(const void* a, const void* b, size_t size); + +/** Calculate a hash value for the given buffer and seed; accepts NULL-buffer. */ +LIBXSMM_API unsigned int libxsmm_hash(const void* data, unsigned int size, unsigned int seed); + +/** Calculate a 64-bit hash for the given character string; accepts NULL-string. */ +LIBXSMM_API unsigned long long libxsmm_hash_string(const char* string); + +/** Return the pointer to the 1st match of "b" in "a", or NULL (no match). */ +LIBXSMM_API const char* libxsmm_stristr(const char* a, const char* b); + +/** + * Check if pointer is SIMD-aligned and optionally consider the next access (increment in Bytes). + * Optionally calculates the alignment of the given pointer in Bytes. + */ +LIBXSMM_API int libxsmm_aligned(const void* ptr, const size_t* inc, int* alignment); + +#endif /*LIBXSMM_MEMORY_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_mhd.h b/third_party/libxsmm/include/libxsmm_mhd.h new file mode 100644 index 0000000000000000000000000000000000000000..ab4cf17411fb9a04e543f9d43291414c3bc7c020 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_mhd.h @@ -0,0 +1,167 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MHD_H +#define LIBXSMM_MHD_H + +#include "libxsmm_typedefs.h" + + +/** Denotes the element/pixel type of an image/channel. */ +typedef enum libxsmm_mhd_elemtype { + LIBXSMM_MHD_ELEMTYPE_F64 = LIBXSMM_DATATYPE_F64, /* MET_DOUBLE */ + LIBXSMM_MHD_ELEMTYPE_F32 = LIBXSMM_DATATYPE_F32, /* MET_FLOAT */ + LIBXSMM_MHD_ELEMTYPE_BF16 = LIBXSMM_DATATYPE_BF16, /* MET_BFLOAT */ + LIBXSMM_MHD_ELEMTYPE_I64 = LIBXSMM_DATATYPE_I64, /* MET_LONG */ + LIBXSMM_MHD_ELEMTYPE_I32 = LIBXSMM_DATATYPE_I32, /* MET_INT */ + LIBXSMM_MHD_ELEMTYPE_I16 = LIBXSMM_DATATYPE_I16, /* MET_SHORT */ + LIBXSMM_MHD_ELEMTYPE_I8 = LIBXSMM_DATATYPE_I8, /* MET_CHAR */ + LIBXSMM_MHD_ELEMTYPE_U64 = LIBXSMM_DATATYPE_UNSUPPORTED, /* MET_ULONG */ + LIBXSMM_MHD_ELEMTYPE_U32, /* MET_UINT */ + LIBXSMM_MHD_ELEMTYPE_U16, /* MET_USHORT */ + LIBXSMM_MHD_ELEMTYPE_U8, /* MET_UCHAR */ + LIBXSMM_MHD_ELEMTYPE_UNKNOWN +} libxsmm_mhd_elemtype; + + +/** + * Function type used for custom data-handler or element conversion. + * The value-range (src_min, src_max) may be used to scale values + * in case of a type-conversion. + */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE int (*libxsmm_mhd_element_handler)( + void* dst, libxsmm_mhd_elemtype dst_type, libxsmm_mhd_elemtype src_type, + const void* src, const void* src_min, const void* src_max); + +/** + * Predefined function to perform element data conversion. + * Scales source-values in case of non-NULL src_min and src_max, + * or otherwise clamps to the destination-type. + */ +LIBXSMM_API int libxsmm_mhd_element_conversion( + void* dst, libxsmm_mhd_elemtype dst_type, libxsmm_mhd_elemtype src_type, + const void* src, const void* src_min, const void* src_max); + +/** + * Predefined function to check a buffer against file content. + * In case of different types, libxsmm_mhd_element_conversion + * is performed to compare values using the source-type. + */ +LIBXSMM_API int libxsmm_mhd_element_comparison( + void* dst, libxsmm_mhd_elemtype dst_type, libxsmm_mhd_elemtype src_type, + const void* src, const void* src_min, const void* src_max); + + +/** Returns the name and size of the element type; result may be NULL/0 in case of an unknown type. */ +LIBXSMM_API const char* libxsmm_mhd_typename(libxsmm_mhd_elemtype type, size_t* typesize, const char** ctypename); + +/** Returns the type of the element for a given type-name. */ +LIBXSMM_API libxsmm_mhd_elemtype libxsmm_mhd_typeinfo(const char elemname[]); + + +/** + * Parse the header of an MHD-file. The header can be part of the data file (local), + * or separately stored (header: MHD, data MHA or RAW). + */ +LIBXSMM_API int libxsmm_mhd_read_header( + /* Filename referring to the header-file (may also contain the data). */ + const char header_filename[], + /* Maximum length of path/file name. */ + size_t filename_max_length, + /* Filename containing the data (may be the same as the header-file). */ + char filename[], + /* Yields the maximum/possible number of dimensions on input, + * and the actual number of dimensions on output. */ + size_t* ndims, + /* Image extents ("ndims" number of entries). */ + size_t size[], + /* Number of interleaved image channels. */ + size_t* ncomponents, + /* Type of the image elements (pixel type). */ + libxsmm_mhd_elemtype* type, + /* Size of the header in bytes; may be used to skip the header, + * when reading content; can be a NULL-argument (optional). */ + size_t* header_size, + /* Size (in Bytes) of an user-defined extended data record; + * can be a NULL-argument (optional). */ + size_t* extension_size); + + +/** + * Loads the data file, and optionally allows data conversion. + * Conversion is performed such that values are clamped to fit + * into the destination. + */ +LIBXSMM_API int libxsmm_mhd_read( + /* Filename referring to the data. */ + const char filename[], + /* Offset within pitched buffer (NULL: no offset). */ + const size_t offset[], + /* Image dimensions (extents). */ + const size_t size[], + /* Leading buffer dimensions (NULL: same as size). */ + const size_t pitch[], + /* Dimensionality (number of entries in size). */ + size_t ndims, + /* Number of interleaved image channels. */ + size_t ncomponents, + /* Used to skip the header, and to only read the data. */ + size_t header_size, + /* Data element type as stored (pixel type). */ + libxsmm_mhd_elemtype type_stored, + /* Storage type (data conversion, optional). */ + const libxsmm_mhd_elemtype* type_data, + /* Buffer where the data is read into. */ + void* data, + /** + * Optional callback executed per entry when reading the data. + * May assign the value to the left-most argument, but also + * allows to only compare with present data. Can be used to + * avoid allocating an actual destination. + */ + libxsmm_mhd_element_handler handle_element, + /* Post-content data (extension, optional). */ + char extension[], + /* Size of the extension; can be zero. */ + size_t extension_size); + + +/** + * Save a file using an extended data format, which is compatible with the Meta Image Format (MHD). + * The file is suitable for visual inspection using, e.g., ITK-SNAP or ParaView. + */ +LIBXSMM_API int libxsmm_mhd_write(const char filename[], + /* Offset within pitched buffer (NULL: no offset). */ + const size_t offset[], + /* Image dimensions (extents). */ + const size_t size[], + /* Leading buffer dimensions (NULL: same as size). */ + const size_t pitch[], + /* Dimensionality, i.e., number of entries in data_size/size. */ + size_t ndims, + /* Number of pixel components. */ + size_t ncomponents, + /* Type (input). */ + libxsmm_mhd_elemtype type_data, + /* Type (data conversion, optional). */ + const libxsmm_mhd_elemtype* type, + /* Raw data to be saved. */ + const void* data, + /* Size of the header; can be a NULL-argument (optional). */ + size_t* header_size, + /* Extension header data; can be NULL. */ + const char extension_header[], + /* Extension data stream; can be NULL. */ + const void* extension, + /* Extension data size; can be NULL. */ + size_t extension_size); + +#endif /*LIBXSMM_MHD_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_rng.h b/third_party/libxsmm/include/libxsmm_rng.h new file mode 100644 index 0000000000000000000000000000000000000000..fa0ae51471ad3618e88c0829b9d500708526c2d0 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_rng.h @@ -0,0 +1,57 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_RNG_H +#define LIBXSMM_RNG_H + +#include "libxsmm_typedefs.h" + +/** + * create a new external state for thread-save execution managed + * by the user. We do not provide a function for drawing the random numbers + * the user is supposed to call the LIBXSMM_INTRINSICS_MM512_RNG_EXTSTATE_PS + * or LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32 intrinsic. + * */ +LIBXSMM_API unsigned int* libxsmm_rng_create_extstate(unsigned int/*uint32_t*/ seed); + +/** free a previously created rng_avx512_extstate */ +LIBXSMM_API void libxsmm_rng_destroy_extstate(unsigned int* stateptr); + +/** Set the seed of libxsmm_rng_* (similar to srand). */ +LIBXSMM_API void libxsmm_rng_set_seed(unsigned int/*uint32_t*/ seed); + +/** + * This SP-RNG is using xoshiro128+ 1.0, work done by + * David Blackman and Sebastiano Vigna (vigna@acm.org). + * It is their best and fastest 32-bit generator for + * 32-bit floating-point numbers. They suggest to use + * its upper bits for floating-point generation, what + * we do here and generate numbers in [0,1(. + */ +LIBXSMM_API void libxsmm_rng_f32_seq(float* rngs, libxsmm_blasint count); + +/** + * Returns a (pseudo-)random value based on rand/rand48 in the interval [0, n). + * This function compensates for an n, which is not a factor of RAND_MAX. + * Note: libxsmm_rng_set_seed must be used if one wishes to seed the generator. + */ +LIBXSMM_API unsigned int libxsmm_rng_u32(unsigned int n); + +/** Sequence of random data based on libxsmm_rng_u32. */ +LIBXSMM_API void libxsmm_rng_seq(void* data, libxsmm_blasint nbytes); + +/** + * Similar to libxsmm_rng_u32, but returns a DP-value in the interval [0, 1). + * Note: libxsmm_rng_set_seed must be used if one wishes to seed the generator. + */ +LIBXSMM_API double libxsmm_rng_f64(void); + +#endif /* LIBXSMM_RNG_H */ + diff --git a/third_party/libxsmm/include/libxsmm_source.h b/third_party/libxsmm/include/libxsmm_source.h new file mode 100644 index 0000000000000000000000000000000000000000..645cae215c422956ca06e35803133b49ba83af76 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_source.h @@ -0,0 +1,144 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_SOURCE_H +#define LIBXSMM_SOURCE_H + +#if defined(LIBXSMM_MACROS_H) +# error Please do not include any LIBXSMM header other than libxsmm_source.h! +#endif +#if defined(LIBXSMM_BUILD) +# error LIBXSMM_BUILD cannot be defined for the header-only LIBXSMM! +#endif + +/** + * This header is intentionally called "libxsmm_source.h" since the followings block + * includes *internal* files, and thereby exposes LIBXSMM's implementation. + * The so-called "header-only" usage model gives up the clearly defined binary interface + * (including support for hot-fixes after deployment), and requires to rebuild client + * code for every (internal) change of LIBXSMM. Please make sure to only rely on the + * public interface as the internal implementation may change without notice. + */ +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include "../src/generator_aarch64_instructions.c" +#include "../src/generator_common.c" +#include "../src/generator_common_aarch64.c" +#include "../src/generator_common_x86.c" +#include "../src/generator_gemm.c" +#include "../src/generator_gemm_aarch64.c" +#include "../src/generator_gemm_amx.c" +#include "../src/generator_gemm_amx_emu.c" +#include "../src/generator_gemm_amx_microkernel.c" +#include "../src/generator_gemm_amx_microkernel_emu.c" +#include "../src/generator_gemm_avx2_microkernel.c" +#include "../src/generator_gemm_avx512_microkernel.c" +#include "../src/generator_gemm_avx_microkernel.c" +#include "../src/generator_gemm_common.c" +#include "../src/generator_gemm_common_aarch64.c" +#include "../src/generator_gemm_noarch.c" +#include "../src/generator_gemm_sse_avx_avx2_avx512.c" +#include "../src/generator_gemm_sse_microkernel.c" +#include "../src/generator_mateltwise.c" +#include "../src/generator_mateltwise_misc_avx_avx512.c" +#include "../src/generator_mateltwise_reduce_avx_avx512.c" +#include "../src/generator_mateltwise_sse_avx_avx512.c" +#include "../src/generator_mateltwise_transform_avx.c" +#include "../src/generator_mateltwise_transform_avx512.c" +#include "../src/generator_mateltwise_transform_common.c" +#include "../src/generator_mateltwise_transform_common_x86.c" +#include "../src/generator_mateltwise_transform_sse.c" +#include "../src/generator_mateltwise_unary_binary_avx_avx512.c" +#include "../src/generator_matequation.c" +#include "../src/generator_matequation_avx_avx512.c" +#include "../src/generator_matequation_regblocks_avx_avx512.c" +#include "../src/generator_matequation_scratch_avx_avx512.c" +#include "../src/generator_packed_gemm_ac_rm.c" +#include "../src/generator_packed_gemm_ac_rm_aarch64.c" +#include "../src/generator_packed_gemm_ac_rm_avx_avx2_avx512.c" +#include "../src/generator_packed_gemm_bc_rm.c" +#include "../src/generator_packed_gemm_bc_rm_aarch64.c" +#include "../src/generator_packed_gemm_bc_rm_avx_avx2_avx512.c" +#include "../src/generator_packed_spgemm.c" +#include "../src/generator_packed_spgemm_csc_bsparse.c" +#include "../src/generator_packed_spgemm_csc_bsparse_aarch64.c" +#include "../src/generator_packed_spgemm_csc_bsparse_avx_avx2_avx512.c" +#include "../src/generator_packed_spgemm_csc_csparse.c" +#include "../src/generator_packed_spgemm_csc_csparse_avx_avx2_avx512.c" +#include "../src/generator_packed_spgemm_csr_asparse.c" +#include "../src/generator_packed_spgemm_csr_asparse_aarch64.c" +#include "../src/generator_packed_spgemm_csr_asparse_avx_avx2_avx512.c" +#include "../src/generator_packed_spgemm_csr_bsparse.c" +#include "../src/generator_packed_spgemm_csr_bsparse_aarch64.c" +#include "../src/generator_packed_spgemm_csr_bsparse_avx_avx2_avx512.c" +#include "../src/generator_spgemm.c" +#include "../src/generator_spgemm_csc_asparse.c" +#include "../src/generator_spgemm_csc_bsparse.c" +#include "../src/generator_spgemm_csc_reader.c" +#include "../src/generator_spgemm_csr_asparse.c" +#include "../src/generator_spgemm_csr_asparse_reg.c" +#include "../src/generator_spgemm_csr_reader.c" +#include "../src/generator_x86_instructions.c" +#include "../src/libxsmm_cpuid_arm.c" +#include "../src/libxsmm_cpuid_x86.c" +#include "../src/libxsmm_dnn.c" +#include "../src/libxsmm_dnn_convolution.c" +#include "../src/libxsmm_dnn_convolution_backward.c" +#include "../src/libxsmm_dnn_convolution_forward.c" +#include "../src/libxsmm_dnn_convolution_weight_update.c" +#include "../src/libxsmm_dnn_elementwise.c" +#include "../src/libxsmm_dnn_fullyconnected.c" +#include "../src/libxsmm_dnn_fullyconnected_backward_weight_update.c" +#include "../src/libxsmm_dnn_fullyconnected_forward.c" +#include "../src/libxsmm_dnn_fusedbatchnorm.c" +#include "../src/libxsmm_dnn_fusedbatchnorm_backward.c" +#include "../src/libxsmm_dnn_fusedbatchnorm_forward.c" +#include "../src/libxsmm_dnn_fusedgroupnorm.c" +#include "../src/libxsmm_dnn_fusedgroupnorm_backward.c" +#include "../src/libxsmm_dnn_fusedgroupnorm_forward.c" +#include "../src/libxsmm_dnn_optimizer.c" +#include "../src/libxsmm_dnn_optimizer_sgd.c" +#include "../src/libxsmm_dnn_pooling.c" +#include "../src/libxsmm_dnn_pooling_backward.c" +#include "../src/libxsmm_dnn_pooling_forward.c" +#include "../src/libxsmm_dnn_rnncell.c" +#include "../src/libxsmm_dnn_rnncell_backward_weight_update.c" +#include "../src/libxsmm_dnn_rnncell_forward.c" +#include "../src/libxsmm_dnn_softmaxloss.c" +#include "../src/libxsmm_dnn_softmaxloss_backward.c" +#include "../src/libxsmm_dnn_softmaxloss_forward.c" +#include "../src/libxsmm_dnn_tensor.c" +#include "../src/libxsmm_ext.c" +#include "../src/libxsmm_ext_gemm.c" +#include "../src/libxsmm_ext_xcopy.c" +#include "../src/libxsmm_fsspmdm.c" +#include "../src/libxsmm_gemm.c" +#include "../src/libxsmm_generator.c" +#include "../src/libxsmm_hash.c" +#include "../src/libxsmm_main.c" +#include "../src/libxsmm_malloc.c" +#include "../src/libxsmm_math.c" +#include "../src/libxsmm_matrixeqn.c" +#include "../src/libxsmm_memory.c" +#include "../src/libxsmm_mhd.c" +#include "../src/libxsmm_perf.c" +#include "../src/libxsmm_python.c" +#include "../src/libxsmm_rng.c" +#include "../src/libxsmm_spmdm.c" +#include "../src/libxsmm_sync.c" +#include "../src/libxsmm_timer.c" +#include "../src/libxsmm_trace.c" +#include "../src/libxsmm_xcopy.c" +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#endif /*LIBXSMM_SOURCE_H*/ diff --git a/third_party/libxsmm/include/libxsmm_spmdm.h b/third_party/libxsmm/include/libxsmm_spmdm.h new file mode 100644 index 0000000000000000000000000000000000000000..1f452dd396c47225c0ccdfb484160fa2bf69c32d --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_spmdm.h @@ -0,0 +1,115 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_SPMDM_H +#define LIBXSMM_SPMDM_H + +#include "libxsmm_typedefs.h" + + +typedef enum libxsmm_spmdm_datatype { + LIBXSMM_SPMDM_DATATYPE_F32, + LIBXSMM_SPMDM_DATATYPE_BFLOAT16 +} libxsmm_spmdm_datatype; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_spmdm_handle { + /* The following are the matrix multiply dimensions: A (sparse): m X k, B (dense): k X n, Output C (dense): m X n */ + int m; + int n; + int k; + /* The block sizes for A, B and C. */ + /* Here we fix A to be divided into 128 X 128 blocks, B/C to be 128 X 48 for HSW/BDW and 128 X 96 for SKX */ + int bm; + int bn; + int bk; + /* The number of blocks for the m, n and k dimensions */ + int mb; + int nb; + int kb; + libxsmm_spmdm_datatype datatype; + char* base_ptr_scratch_A; + char* base_ptr_scratch_B_scratch_C; + int memory_for_scratch_per_thread; +} libxsmm_spmdm_handle; + +/** + * This stores a single sparse splice (or block) of sparse matrix A using a CSR representation (rowidx, colidx, and values + * Each splice corresponds to a bm X bk region of A, and stores local indexes + */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_CSR_sparseslice { + /* Since bm and bk are assumed to be <=256, a 16-bit integer is enough to store the local rowidx, colidx */ + uint16_t* rowidx; + uint16_t* colidx; + float* values; +} libxsmm_CSR_sparseslice; + + +LIBXSMM_API void libxsmm_spmdm_init( + int M, int N, int K, + int max_threads, + libxsmm_spmdm_handle* handle, + libxsmm_CSR_sparseslice** libxsmm_output_csr); + +LIBXSMM_API void libxsmm_spmdm_destroy( + libxsmm_spmdm_handle* handle); + +LIBXSMM_API int libxsmm_spmdm_get_num_createSparseSlice_blocks( + const libxsmm_spmdm_handle* handle); + +LIBXSMM_API int libxsmm_spmdm_get_num_compute_blocks( + const libxsmm_spmdm_handle* handle); + +/** This converts a dense representation of the sparse matrix to 2D array of sparse slices. */ +LIBXSMM_API void libxsmm_spmdm_createSparseSlice_fp32_thread( + const libxsmm_spmdm_handle* handle, + char transa, + const float* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads); + +LIBXSMM_API void libxsmm_spmdm_createSparseSlice_bfloat16_thread( + const libxsmm_spmdm_handle* handle, + char transa, + const libxsmm_bfloat16* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads); + +/** NOTE: This code currently ignores alpha input to the matrix multiply */ +LIBXSMM_API void libxsmm_spmdm_compute_fp32_thread( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const float* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const float* b, + char transc, + const float* beta, + float* c, + int block_id, + int tid, int nthreads); + +/** NOTE: This code currently ignores alpha input to the matrix multiply */ +LIBXSMM_API void libxsmm_spmdm_compute_bfloat16_thread( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const libxsmm_bfloat16* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const libxsmm_bfloat16* b, + char transc, + const libxsmm_bfloat16* beta, + float* c, + int block_id, + int tid, int nthreads); + +#endif /*LIBXSMM_SPMDM_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_sync.h b/third_party/libxsmm/include/libxsmm_sync.h new file mode 100644 index 0000000000000000000000000000000000000000..1f40fab1c2b4838361909ce8199b06560312ed14 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_sync.h @@ -0,0 +1,816 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_SYNC_H +#define LIBXSMM_SYNC_H + +#include "libxsmm_intrinsics_x86.h" + +#if !defined(LIBXSMM_TLS) +# if (0 != LIBXSMM_SYNC) && !defined(LIBXSMM_NO_TLS) +# if defined(__CYGWIN__) && defined(__clang__) +# define LIBXSMM_NO_TLS +# define LIBXSMM_TLS +# else +# if (defined(_WIN32) && !defined(__GNUC__) && !defined(__clang__)) || (defined(__PGI) && !defined(__cplusplus)) +# define LIBXSMM_TLS LIBXSMM_ATTRIBUTE(thread) +# elif defined(__GNUC__) || defined(__clang__) || defined(_CRAYC) +# define LIBXSMM_TLS __thread +# elif defined(__cplusplus) +# define LIBXSMM_TLS thread_local +# else +# error Missing TLS support! +# endif +# endif +# else +# if !defined(LIBXSMM_NO_TLS) +# define LIBXSMM_NO_TLS +# endif +# define LIBXSMM_TLS +# endif +#endif + +#if !defined(LIBXSMM_GCC_BASELINE) && !defined(LIBXSMM_SYNC_LEGACY) && ((defined(_WIN32) && defined(__clang__)) || \ + (defined(__GNUC__) && LIBXSMM_VERSION2(4, 7) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) +# define LIBXSMM_GCC_BASELINE +#endif + +#if defined(__MIC__) +# define LIBXSMM_SYNC_PAUSE _mm_delay_32(8/*delay*/) +#elif !defined(LIBXSMM_INTRINSICS_NONE) +# if defined(LIBXSMM_GCC_BASELINE) && !defined(__INTEL_COMPILER) +# define LIBXSMM_SYNC_PAUSE __builtin_ia32_pause() +# else +# define LIBXSMM_SYNC_PAUSE _mm_pause() +# endif +#elif (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH) && defined(__GNUC__) +# define LIBXSMM_SYNC_PAUSE __asm__ __volatile__("pause" ::: "memory") +#else +# define LIBXSMM_SYNC_PAUSE +#endif + +/* permit thread-unsafe */ +#if !defined(LIBXSMM_SYNC_NONE) && ( \ + (defined(__PGI) && (!defined(LIBXSMM_LIBATOMIC) || !defined(__STATIC))) || \ + (defined(_CRAYC) && !defined(__GNUC__))) +# define LIBXSMM_SYNC_NONE +#endif + +#if !defined(LIBXSMM_ATOMIC_TRYLOCK_CMPSWP) && 0 +# define LIBXSMM_ATOMIC_TRYLOCK_CMPSWP +#endif +#if !defined(LIBXSMM_ATOMIC_ZERO_STORE) && defined(_CRAYC) +# define LIBXSMM_ATOMIC_ZERO_STORE +#endif +#if !defined(LIBXSMM_ATOMIC_LOCKTYPE) +# if defined(_WIN32) || 1/*alignment*/ +# define LIBXSMM_ATOMIC_LOCKTYPE int +# else +# define LIBXSMM_ATOMIC_LOCKTYPE char +# endif +#endif + +typedef enum libxsmm_atomic_kind { +#if defined(__ATOMIC_SEQ_CST) + LIBXSMM_ATOMIC_SEQ_CST = __ATOMIC_SEQ_CST, +#else + LIBXSMM_ATOMIC_SEQ_CST = 0, +#endif +#if defined(__ATOMIC_RELAXED) + LIBXSMM_ATOMIC_RELAXED = __ATOMIC_RELAXED +#else + LIBXSMM_ATOMIC_RELAXED = LIBXSMM_ATOMIC_SEQ_CST +#endif +} libxsmm_atomic_kind; + +#define LIBXSMM_NONATOMIC_LOCKTYPE LIBXSMM_ATOMIC_LOCKTYPE +#define LIBXSMM_NONATOMIC_LOAD(SRC_PTR, KIND) (*(SRC_PTR)) +#define LIBXSMM_NONATOMIC_STORE(DST_PTR, VALUE, KIND) { LIBXSMM_UNUSED(KIND); *(DST_PTR) = (VALUE); } +#define LIBXSMM_NONATOMIC_STORE_ZERO(DST_PTR, KIND) LIBXSMM_NONATOMIC_STORE(DST_PTR, 0, KIND) +#define LIBXSMM_NONATOMIC_FETCH_OR(DST_PTR, VALUE/*side-effect*/, KIND) (/* 1st step: swap(dst, val) */ \ + ((*DST_PTR) = (*DST_PTR) ^ (VALUE)), (VALUE = (VALUE) ^ (*DST_PTR)), ((*DST_PTR) = (*DST_PTR) ^ (VALUE)), \ + (*(DST_PTR) |= VALUE), (VALUE) /* 2nd step: or, and 3rd/last step: original dst-value */) +#define LIBXSMM_NONATOMIC_ADD_FETCH(DST_PTR, VALUE, KIND) (*(DST_PTR) += VALUE) +#define LIBXSMM_NONATOMIC_SUB_FETCH(DST_PTR, VALUE, KIND) (*(DST_PTR) -= VALUE) +#define LIBXSMM_NONATOMIC_FETCH_ADD(DST_PTR, VALUE, KIND) (LIBXSMM_NONATOMIC_ADD_FETCH(DST_PTR, VALUE, KIND), (*(DST_PTR) - (VALUE))) +#define LIBXSMM_NONATOMIC_FETCH_SUB(DST_PTR, VALUE, KIND) (LIBXSMM_NONATOMIC_SUB_FETCH(DST_PTR, VALUE, KIND), (*(DST_PTR) + (VALUE))) +#define LIBXSMM_NONATOMIC_CMPSWP(DST_PTR, OLDVAL, NEWVAL, KIND) ((NEWVAL) == (*(DST_PTR) == (OLDVAL) ? (*(DST_PTR) = (NEWVAL)) : (OLDVAL))) +#define LIBXSMM_NONATOMIC_TRYLOCK(DST_PTR, KIND) LIBXSMM_NONATOMIC_CMPSWP(DST_PTR, 0, 1, KIND) +#define LIBXSMM_NONATOMIC_ACQUIRE(DST_PTR, NPAUSE, KIND) { LIBXSMM_UNUSED(NPAUSE); \ + LIBXSMM_ASSERT_MSG(0 == *(DST_PTR), "LIBXSMM_NONATOMIC_ACQUIRE"); LIBXSMM_NONATOMIC_STORE(DST_PTR, 1, KIND); \ + LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_NONATOMIC_ACQUIRE"); } +#define LIBXSMM_NONATOMIC_RELEASE(DST_PTR, KIND) { LIBXSMM_UNUSED(DST_PTR); LIBXSMM_UNUSED(KIND); \ + LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_NONATOMIC_RELEASE"); LIBXSMM_NONATOMIC_STORE(DST_PTR, 0, KIND); \ + LIBXSMM_ASSERT_MSG(0 == *(DST_PTR), "LIBXSMM_NONATOMIC_RELEASE"); } +#define LIBXSMM_NONATOMIC_SYNC(KIND) LIBXSMM_UNUSED(KIND) + +#if (0 == LIBXSMM_SYNC) || defined(LIBXSMM_SYNC_NONE) +# define LIBXSMM_ATOMIC(FN, BITS) FN +# define LIBXSMM_ATOMIC_LOAD LIBXSMM_NONATOMIC_LOAD +# define LIBXSMM_ATOMIC_STORE LIBXSMM_NONATOMIC_STORE +# define LIBXSMM_ATOMIC_STORE_ZERO LIBXSMM_NONATOMIC_STORE_ZERO +# define LIBXSMM_ATOMIC_FETCH_OR LIBXSMM_NONATOMIC_FETCH_OR +# define LIBXSMM_ATOMIC_ADD_FETCH LIBXSMM_NONATOMIC_ADD_FETCH +# define LIBXSMM_ATOMIC_SUB_FETCH LIBXSMM_NONATOMIC_SUB_FETCH +# define LIBXSMM_ATOMIC_FETCH_ADD LIBXSMM_NONATOMIC_FETCH_ADD +# define LIBXSMM_ATOMIC_FETCH_SUB LIBXSMM_NONATOMIC_FETCH_SUB +# define LIBXSMM_ATOMIC_CMPSWP LIBXSMM_NONATOMIC_CMPSWP +# define LIBXSMM_ATOMIC_TRYLOCK LIBXSMM_NONATOMIC_TRYLOCK +# define LIBXSMM_ATOMIC_ACQUIRE LIBXSMM_NONATOMIC_ACQUIRE +# define LIBXSMM_ATOMIC_RELEASE LIBXSMM_NONATOMIC_RELEASE +# define LIBXSMM_ATOMIC_SYNC LIBXSMM_NONATOMIC_SYNC +# if !defined(LIBXSMM_SYNC_NPAUSE) +# define LIBXSMM_SYNC_NPAUSE 0 +# endif +#elif (defined(LIBXSMM_GCC_BASELINE) || defined(LIBXSMM_LIBATOMIC) /* GNU's libatomic required */ || \ + (defined(__GNUC__) && LIBXSMM_VERSION2(4, 1) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) +# if defined(LIBXSMM_LIBATOMIC) +# define LIBXSMM_ATOMIC(FN, BITS) LIBXSMM_CONCATENATE(LIBXSMM_ATOMIC, BITS)(FN) +# define LIBXSMM_ATOMIC8(FN) LIBXSMM_CONCATENATE(FN, 8) +# define LIBXSMM_ATOMIC16(FN) LIBXSMM_CONCATENATE(FN, 16) +# define LIBXSMM_ATOMIC32(FN) FN/*default*/ +# define LIBXSMM_ATOMIC64(FN) LIBXSMM_CONCATENATE(FN, 64) +# if defined(__PGI) +# define LIBXSMM_ATOMIC_LOAD(SRC_PTR, KIND) LIBXSMM_NONATOMIC_LOAD(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_LOAD8(SRC_PTR, KIND) LIBXSMM_NONATOMIC_LOAD(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_LOAD16(SRC_PTR, KIND) LIBXSMM_NONATOMIC_LOAD(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_LOAD64(SRC_PTR, KIND) LIBXSMM_NONATOMIC_LOAD(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_STORE(DST_PTR, VALUE, KIND) LIBXSMM_NONATOMIC_STORE(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_STORE8(DST_PTR, VALUE, KIND) LIBXSMM_NONATOMIC_STORE(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_STORE16(DST_PTR, VALUE, KIND) LIBXSMM_NONATOMIC_STORE(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_STORE64(DST_PTR, VALUE, KIND) LIBXSMM_NONATOMIC_STORE(DST_PTR, VALUE, KIND) +# else +# define LIBXSMM_ATOMIC_LOAD(SRC_PTR, KIND) __atomic_load_4(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_LOAD8(SRC_PTR, KIND) __atomic_load_1(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_LOAD16(SRC_PTR, KIND) __atomic_load_2(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_LOAD64(SRC_PTR, KIND) __atomic_load_8(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_STORE(DST_PTR, VALUE, KIND) __atomic_store_4(DST_PTR, (unsigned int)(VALUE), KIND) +# define LIBXSMM_ATOMIC_STORE8(DST_PTR, VALUE, KIND) __atomic_store_1(DST_PTR, (unsigned char)(VALUE), KIND) +# define LIBXSMM_ATOMIC_STORE16(DST_PTR, VALUE, KIND) __atomic_store_2(DST_PTR, (unsigned short)(VALUE), KIND) +# define LIBXSMM_ATOMIC_STORE64(DST_PTR, VALUE, KIND) __atomic_store_8(DST_PTR, (unsigned long long)(VALUE), KIND) +# endif +# define LIBXSMM_ATOMIC_FETCH_OR(DST_PTR, VALUE, KIND) __atomic_fetch_or_4(DST_PTR, (unsigned int)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_OR8(DST_PTR, VALUE, KIND) __atomic_fetch_or_1(DST_PTR, (unsigned char)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_OR16(DST_PTR, VALUE, KIND) __atomic_fetch_or_2(DST_PTR, (unsigned short)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_OR64(DST_PTR, VALUE, KIND) __atomic_fetch_or_8(DST_PTR, (unsigned long long)(VALUE), KIND) +# define LIBXSMM_ATOMIC_ADD_FETCH(DST_PTR, VALUE, KIND) __atomic_add_fetch_4(DST_PTR, (int)(VALUE), KIND) +# define LIBXSMM_ATOMIC_ADD_FETCH8(DST_PTR, VALUE, KIND) __atomic_add_fetch_1(DST_PTR, (signed char)(VALUE), KIND) +# define LIBXSMM_ATOMIC_ADD_FETCH16(DST_PTR, VALUE, KIND) __atomic_add_fetch_2(DST_PTR, (short)(VALUE), KIND) +# define LIBXSMM_ATOMIC_ADD_FETCH64(DST_PTR, VALUE, KIND) __atomic_add_fetch_8(DST_PTR, (long long)(VALUE), KIND) +# define LIBXSMM_ATOMIC_SUB_FETCH(DST_PTR, VALUE, KIND) __atomic_sub_fetch_4(DST_PTR, (int)(VALUE), KIND) +# define LIBXSMM_ATOMIC_SUB_FETCH8(DST_PTR, VALUE, KIND) __atomic_sub_fetch_1(DST_PTR, (signed char)(VALUE), KIND) +# define LIBXSMM_ATOMIC_SUB_FETCH16(DST_PTR, VALUE, KIND) __atomic_sub_fetch_2(DST_PTR, (short)(VALUE), KIND) +# define LIBXSMM_ATOMIC_SUB_FETCH64(DST_PTR, VALUE, KIND) __atomic_sub_fetch_8(DST_PTR, (long long)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_ADD(DST_PTR, VALUE, KIND) __atomic_fetch_add_4(DST_PTR, (int)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_ADD8(DST_PTR, VALUE, KIND) __atomic_fetch_add_1(DST_PTR, (signed char)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_ADD16(DST_PTR, VALUE, KIND) __atomic_fetch_add_2(DST_PTR, (short)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_ADD64(DST_PTR, VALUE, KIND) __atomic_fetch_add_8(DST_PTR, (long long)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB(DST_PTR, VALUE, KIND) __atomic_fetch_sub_4(DST_PTR, (int)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB8(DST_PTR, VALUE, KIND) __atomic_fetch_sub_1(DST_PTR, (signed char)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB16(DST_PTR, VALUE, KIND) __atomic_fetch_sub_2(DST_PTR, (short)(VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB64(DST_PTR, VALUE, KIND) __atomic_fetch_sub_8(DST_PTR, (long long)(VALUE), KIND) +# define LIBXSMM_ATOMIC_CMPSWP(DST_PTR, OLDVAL, NEWVAL, KIND) \ + __atomic_compare_exchange_4(DST_PTR, &(OLDVAL), (NEWVAL), 0/*false*/, KIND, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_ATOMIC_CMPSWP8(DST_PTR, OLDVAL, NEWVAL, KIND) \ + __atomic_compare_exchange_1(DST_PTR, &(OLDVAL), (NEWVAL), 0/*false*/, KIND, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_ATOMIC_CMPSWP16(DST_PTR, OLDVAL, NEWVAL, KIND) \ + __atomic_compare_exchange_2(DST_PTR, &(OLDVAL), (NEWVAL), 0/*false*/, KIND, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_ATOMIC_CMPSWP64(DST_PTR, OLDVAL, NEWVAL, KIND) \ + __atomic_compare_exchange_8(DST_PTR, &(OLDVAL), (NEWVAL), 0/*false*/, KIND, LIBXSMM_ATOMIC_RELAXED) +# if defined(LIBXSMM_ATOMIC_TRYLOCK_CMPSWP) +# define LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND) (!__atomic_test_and_set(DST_PTR, KIND)) +# endif +# if defined(__PGI) +# define LIBXSMM_ATOMIC_RELEASE(DST_PTR, KIND) { LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_RELEASE"); \ + LIBXSMM_ATOMIC_STORE_ZERO8(DST_PTR, KIND); } /* matches bit-width of LIBXSMM_ATOMIC_LOCKTYPE */ +# else +# define LIBXSMM_ATOMIC_RELEASE(DST_PTR, KIND) { LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_RELEASE"); \ + __atomic_clear(DST_PTR, KIND); } +# endif +# define LIBXSMM_ATOMIC_SYNC(KIND) __sync_synchronize() +# if !defined(LIBXSMM_ATOMIC_ZERO_STORE) +# define LIBXSMM_ATOMIC_ZERO_STORE +# endif +# elif defined(LIBXSMM_GCC_BASELINE) +# define LIBXSMM_ATOMIC(FN, BITS) FN +# define LIBXSMM_ATOMIC_LOAD(SRC_PTR, KIND) __atomic_load_n(SRC_PTR, KIND) +# define LIBXSMM_ATOMIC_STORE(DST_PTR, VALUE, KIND) __atomic_store_n(DST_PTR, VALUE, KIND) +# if !defined(LIBXSMM_ATOMIC_ZERO_STORE) +# define LIBXSMM_ATOMIC_STORE_ZERO(DST_PTR, KIND) do {} while (__atomic_and_fetch(DST_PTR, 0, KIND)) +# endif +# define LIBXSMM_ATOMIC_FETCH_OR(DST_PTR, VALUE, KIND) __atomic_fetch_or(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_ADD_FETCH(DST_PTR, VALUE, KIND) __atomic_add_fetch(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_SUB_FETCH(DST_PTR, VALUE, KIND) __atomic_sub_fetch(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_FETCH_ADD(DST_PTR, VALUE, KIND) __atomic_fetch_add(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB(DST_PTR, VALUE, KIND) __atomic_fetch_sub(DST_PTR, VALUE, KIND) +# define LIBXSMM_ATOMIC_CMPSWP(DST_PTR, OLDVAL, NEWVAL, KIND) __sync_bool_compare_and_swap(DST_PTR, OLDVAL, NEWVAL) +# if defined(LIBXSMM_ATOMIC_TRYLOCK_CMPSWP) +# define LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND) (!__atomic_test_and_set(DST_PTR, KIND)) +# endif +# define LIBXSMM_ATOMIC_RELEASE(DST_PTR, KIND) { LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_RELEASE"); \ + __atomic_clear(DST_PTR, KIND); } +# if 0 /* __atomic_thread_fence: incorrect behavior in libxsmm_barrier (even with LIBXSMM_ATOMIC_SEQ_CST) */ +# define LIBXSMM_ATOMIC_SYNC(KIND) __atomic_thread_fence(KIND) +# else +# define LIBXSMM_ATOMIC_SYNC(KIND) __sync_synchronize() +# endif +# else /* GCC legacy atomics */ +# define LIBXSMM_ATOMIC(FN, BITS) FN +# define LIBXSMM_ATOMIC_LOAD(SRC_PTR, KIND) __sync_or_and_fetch(SRC_PTR, 0) +# if (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_ATOMIC_STORE(DST_PTR, VALUE, KIND) { \ + __asm__ __volatile__("" ::: "memory"); *(DST_PTR) = (VALUE); \ + __asm__ __volatile__("" ::: "memory"); } +# else +# define LIBXSMM_ATOMIC_SYNC_NOFENCE(KIND) +# define LIBXSMM_ATOMIC_STORE(DST_PTR, VALUE, KIND) *(DST_PTR) = (VALUE) +# endif +# if !defined(LIBXSMM_ATOMIC_ZERO_STORE) +# define LIBXSMM_ATOMIC_STORE_ZERO(DST_PTR, KIND) do {} while (__sync_and_and_fetch(DST_PTR, 0)) +# endif +# define LIBXSMM_ATOMIC_FETCH_OR(DST_PTR, VALUE, KIND) __sync_fetch_and_or(DST_PTR, VALUE) +# define LIBXSMM_ATOMIC_ADD_FETCH(DST_PTR, VALUE, KIND) __sync_add_and_fetch(DST_PTR, VALUE) +# define LIBXSMM_ATOMIC_SUB_FETCH(DST_PTR, VALUE, KIND) __sync_sub_and_fetch(DST_PTR, VALUE) +# define LIBXSMM_ATOMIC_FETCH_ADD(DST_PTR, VALUE, KIND) __sync_fetch_and_add(DST_PTR, VALUE) +# define LIBXSMM_ATOMIC_FETCH_SUB(DST_PTR, VALUE, KIND) __sync_fetch_and_sub(DST_PTR, VALUE) +# define LIBXSMM_ATOMIC_CMPSWP(DST_PTR, OLDVAL, NEWVAL, KIND) __sync_bool_compare_and_swap(DST_PTR, OLDVAL, NEWVAL) +# if defined(LIBXSMM_ATOMIC_TRYLOCK_CMPSWP) +# define LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND) (0 == __sync_lock_test_and_set(DST_PTR, 1)) +# endif +# define LIBXSMM_ATOMIC_RELEASE(DST_PTR, KIND) { LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_RELEASE"); \ + __sync_lock_release(DST_PTR); } +# define LIBXSMM_ATOMIC_SYNC(KIND) __sync_synchronize() +# endif +# if defined(LIBXSMM_ATOMIC_ZERO_STORE) +# define LIBXSMM_ATOMIC_STORE_ZERO(DST_PTR, KIND) LIBXSMM_ATOMIC_STORE(DST_PTR, 0, KIND) +# define LIBXSMM_ATOMIC_STORE_ZERO8(DST_PTR, KIND) LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, 8)(DST_PTR, 0, KIND) +# define LIBXSMM_ATOMIC_STORE_ZERO16(DST_PTR, KIND) LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, 16)(DST_PTR, 0, KIND) +# define LIBXSMM_ATOMIC_STORE_ZERO64(DST_PTR, KIND) LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, 64)(DST_PTR, 0, KIND) +# endif +# if !defined(LIBXSMM_ATOMIC_TRYLOCK_CMPSWP) +# define LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND) /* matches bit-width of LIBXSMM_ATOMIC_LOCKTYPE */ \ + (0 == LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_FETCH_OR, 8)(DST_PTR, 1, KIND)) +# endif +# define LIBXSMM_ATOMIC_ACQUIRE(DST_PTR, NPAUSE, KIND) \ + LIBXSMM_ASSERT(0 == LIBXSMM_MOD2((uintptr_t)(DST_PTR), 4)); \ + while (!LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND)) LIBXSMM_SYNC_CYCLE(DST_PTR, 0/*free*/, NPAUSE); \ + LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_ACQUIRE") +# if !defined(LIBXSMM_SYNC_NPAUSE) +# define LIBXSMM_SYNC_NPAUSE 4096 +# endif +#elif defined(_WIN32) +# define LIBXSMM_ATOMIC(FN, BITS) LIBXSMM_CONCATENATE(LIBXSMM_ATOMIC, BITS)(FN) +# define LIBXSMM_ATOMIC8(FN) LIBXSMM_CONCATENATE(FN, 8) +# define LIBXSMM_ATOMIC16(FN) LIBXSMM_CONCATENATE(FN, 16) +# define LIBXSMM_ATOMIC32(FN) FN/*default*/ +# define LIBXSMM_ATOMIC64(FN) LIBXSMM_CONCATENATE(FN, 64) +# define LIBXSMM_ATOMIC_LOAD(SRC_PTR, KIND) InterlockedOr((volatile LONG*)(SRC_PTR), 0) +# define LIBXSMM_ATOMIC_LOAD8(SRC_PTR, KIND) _InterlockedOr8((volatile char*)(SRC_PTR), 0) +# define LIBXSMM_ATOMIC_LOAD64(SRC_PTR, KIND) InterlockedOr64((volatile LONGLONG*)(SRC_PTR), 0) +# define LIBXSMM_ATOMIC_STORE(DST_PTR, VALUE, KIND) InterlockedExchange((volatile LONG*)(DST_PTR), (LONG)(VALUE)) +# define LIBXSMM_ATOMIC_STORE8(DST_PTR, VALUE, KIND) InterlockedExchange8((volatile char*)(DST_PTR), (LONGLONG)(VALUE)) +# define LIBXSMM_ATOMIC_STORE64(DST_PTR, VALUE, KIND) InterlockedExchange64((volatile LONGLONG*)(DST_PTR), (LONGLONG)(VALUE)) +# if defined(LIBXSMM_ATOMIC_ZERO_STORE) +# define LIBXSMM_ATOMIC_STORE_ZERO(DST_PTR, KIND) LIBXSMM_ATOMIC_STORE(DST_PTR, 0, KIND) +# define LIBXSMM_ATOMIC_STORE_ZERO8(DST_PTR, KIND) LIBXSMM_ATOMIC_STORE8(DST_PTR, 0, KIND) +# define LIBXSMM_ATOMIC_STORE_ZERO64(DST_PTR, KIND) LIBXSMM_ATOMIC_STORE64(DST_PTR, 0, KIND) +# else +# define LIBXSMM_ATOMIC_STORE_ZERO(DST_PTR, KIND) InterlockedAnd((volatile LONG*)(DST_PTR), 0) +# define LIBXSMM_ATOMIC_STORE_ZERO8(DST_PTR, KIND) InterlockedAnd8((volatile char*)(DST_PTR), 0) +# define LIBXSMM_ATOMIC_STORE_ZERO64(DST_PTR, KIND) InterlockedAnd64((volatile LONGLONG*)(DST_PTR), 0) +# endif +# define LIBXSMM_ATOMIC_FETCH_OR(DST_PTR, VALUE, KIND) InterlockedOr((volatile LONG*)(DST_PTR), VALUE) +# define LIBXSMM_ATOMIC_FETCH_OR8(DST_PTR, VALUE, KIND) _InterlockedOr8((volatile char*)(DST_PTR), VALUE) +# define LIBXSMM_ATOMIC_ADD_FETCH(DST_PTR, VALUE, KIND) (LIBXSMM_ATOMIC_FETCH_ADD(DST_PTR, VALUE, KIND) + (VALUE)) +# define LIBXSMM_ATOMIC_ADD_FETCH16(DST_PTR, VALUE, KIND) (LIBXSMM_ATOMIC_FETCH_ADD16(DST_PTR, VALUE, KIND) + (VALUE)) +# define LIBXSMM_ATOMIC_ADD_FETCH64(DST_PTR, VALUE, KIND) (LIBXSMM_ATOMIC_FETCH_ADD64(DST_PTR, VALUE, KIND) + (VALUE)) +# define LIBXSMM_ATOMIC_SUB_FETCH(DST_PTR, VALUE, KIND) ((size_t)LIBXSMM_ATOMIC_FETCH_SUB(DST_PTR, VALUE, KIND) - ((size_t)VALUE)) +# define LIBXSMM_ATOMIC_SUB_FETCH16(DST_PTR, VALUE, KIND) (LIBXSMM_ATOMIC_FETCH_SUB16(DST_PTR, VALUE, KIND) - (VALUE)) +# define LIBXSMM_ATOMIC_SUB_FETCH64(DST_PTR, VALUE, KIND) (LIBXSMM_ATOMIC_FETCH_SUB64(DST_PTR, VALUE, KIND) - (VALUE)) +# define LIBXSMM_ATOMIC_FETCH_ADD(DST_PTR, VALUE, KIND) InterlockedExchangeAdd((volatile LONG*)(DST_PTR), VALUE) +# define LIBXSMM_ATOMIC_FETCH_ADD16(DST_PTR, VALUE, KIND) _InterlockedExchangeAdd16((volatile SHORT*)(DST_PTR), VALUE) +# define LIBXSMM_ATOMIC_FETCH_ADD64(DST_PTR, VALUE, KIND) InterlockedExchangeAdd64((volatile LONGLONG*)(DST_PTR), VALUE) +# define LIBXSMM_ATOMIC_FETCH_SUB(DST_PTR, VALUE, KIND) LIBXSMM_ATOMIC_FETCH_ADD(DST_PTR, -1 * (VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB16(DST_PTR, VALUE, KIND) LIBXSMM_ATOMIC_FETCH_ADD16(DST_PTR, -1 * (VALUE), KIND) +# define LIBXSMM_ATOMIC_FETCH_SUB64(DST_PTR, VALUE, KIND) LIBXSMM_ATOMIC_FETCH_ADD64(DST_PTR, -1 * (VALUE), KIND) +# define LIBXSMM_ATOMIC_CMPSWP(DST_PTR, OLDVAL, NEWVAL, KIND) (((LONG)(OLDVAL)) == InterlockedCompareExchange((volatile LONG*)(DST_PTR), NEWVAL, OLDVAL)) +# define LIBXSMM_ATOMIC_CMPSWP8(DST_PTR, OLDVAL, NEWVAL, KIND) ((OLDVAL) == _InterlockedCompareExchange8((volatile char*)(DST_PTR), NEWVAL, OLDVAL)) +# if defined(LIBXSMM_ATOMIC_TRYLOCK_CMPSWP) +# define LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND) LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_CMPSWP, 8)(DST_PTR, 0, 1, KIND) +# else +# define LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND) (0 == LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_FETCH_OR, 8)(DST_PTR, 1, KIND)) +# endif +# define LIBXSMM_ATOMIC_ACQUIRE(DST_PTR, NPAUSE, KIND) \ + LIBXSMM_ASSERT(0 == LIBXSMM_MOD2((uintptr_t)(DST_PTR), 4)); \ + while (!LIBXSMM_ATOMIC_TRYLOCK(DST_PTR, KIND)) LIBXSMM_SYNC_CYCLE(DST_PTR, 0/*free*/, NPAUSE); \ + LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_ACQUIRE") +# define LIBXSMM_ATOMIC_RELEASE(DST_PTR, KIND) { \ + LIBXSMM_ASSERT_MSG(0 != *(DST_PTR), "LIBXSMM_ATOMIC_RELEASE"); \ + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE_ZERO, 8)(DST_PTR, KIND); } +# define LIBXSMM_ATOMIC_SYNC(KIND) _ReadWriteBarrier() +# if !defined(LIBXSMM_SYNC_NPAUSE) +# define LIBXSMM_SYNC_NPAUSE 4096 +# endif +#else /* consider to permit LIBXSMM_SYNC_NONE */ +# error LIBXSMM is missing atomic compiler builtins! +#endif + +#if !defined(LIBXSMM_SYNC_CYCLE) +# if (0 < LIBXSMM_SYNC_NPAUSE) +# define LIBXSMM_SYNC_CYCLE_ELSE(DST_PTR, EXP_STATE, NPAUSE, ELSE) do { int libxsmm_sync_cycle_npause_ = 1; \ + do { int libxsmm_sync_cycle_counter_ = 0; \ + for (; libxsmm_sync_cycle_counter_ < libxsmm_sync_cycle_npause_; ++libxsmm_sync_cycle_counter_) LIBXSMM_SYNC_PAUSE; \ + if (libxsmm_sync_cycle_npause_ < (NPAUSE)) { \ + libxsmm_sync_cycle_npause_ *= 2; \ + } \ + else { \ + libxsmm_sync_cycle_npause_ = (NPAUSE); \ + LIBXSMM_SYNC_YIELD; \ + ELSE \ + } \ + } while(((EXP_STATE) & 1) != (*(DST_PTR) & 1)); \ + } while(0) +# else +# define LIBXSMM_SYNC_CYCLE_ELSE(DST_PTR, EXP_STATE, NPAUSE, ELSE) LIBXSMM_SYNC_PAUSE +# endif +# define LIBXSMM_SYNC_CYCLE(DST_PTR, EXP_STATE, NPAUSE) \ + LIBXSMM_SYNC_CYCLE_ELSE(DST_PTR, EXP_STATE, NPAUSE, /*else*/;) +#endif + +#if (0 != LIBXSMM_SYNC) +# define LIBXSMM_LOCK_DEFAULT LIBXSMM_LOCK_SPINLOCK +# if !defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) && !(defined(_OPENMP) && defined(LIBXSMM_SYNC_OMP)) && \ + (!defined(__linux__) || defined(__USE_XOPEN2K)) && 0/*disabled*/ +# define LIBXSMM_LOCK_SYSTEM_SPINLOCK +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_MUTEX) && !(defined(_OPENMP) && defined(LIBXSMM_SYNC_OMP)) +# define LIBXSMM_LOCK_SYSTEM_MUTEX +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_RWLOCK) && !(defined(_OPENMP) && defined(LIBXSMM_SYNC_OMP)) && \ + (!defined(__linux__) || defined(__USE_XOPEN2K) || defined(__USE_UNIX98)) +# define LIBXSMM_LOCK_SYSTEM_RWLOCK +# endif + /* Lock type, initialization, destruction, (try-)lock, unlock, etc */ +# define LIBXSMM_LOCK_ACQUIRED(KIND) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_ACQUIRED_, KIND) +# define LIBXSMM_LOCK_TYPE_ISPOD(KIND) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_TYPE_ISPOD_, KIND) +# define LIBXSMM_LOCK_TYPE_ISRW(KIND) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_TYPE_ISRW_, KIND) +# define LIBXSMM_LOCK_TYPE(KIND) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_TYPE_, KIND) +# define LIBXSMM_LOCK_INIT(KIND, LOCK, ATTR) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_INIT_, KIND)(LOCK, ATTR) +# define LIBXSMM_LOCK_DESTROY(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_DESTROY_, KIND)(LOCK) +# define LIBXSMM_LOCK_TRYLOCK(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_TRYLOCK_, KIND)(LOCK) +# define LIBXSMM_LOCK_ACQUIRE(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_ACQUIRE_, KIND)(LOCK) +# define LIBXSMM_LOCK_RELEASE(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_RELEASE_, KIND)(LOCK) +# define LIBXSMM_LOCK_TRYREAD(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_TRYREAD_, KIND)(LOCK) +# define LIBXSMM_LOCK_ACQREAD(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_ACQREAD_, KIND)(LOCK) +# define LIBXSMM_LOCK_RELREAD(KIND, LOCK) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_RELREAD_, KIND)(LOCK) + /* Attribute type, initialization, destruction */ +# define LIBXSMM_LOCK_ATTR_TYPE(KIND) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_ATTR_TYPE_, KIND) +# define LIBXSMM_LOCK_ATTR_INIT(KIND, ATTR) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_ATTR_INIT_, KIND)(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY(KIND, ATTR) LIBXSMM_CONCATENATE(LIBXSMM_LOCK_ATTR_DESTROY_, KIND)(ATTR) + /* Cygwin's Pthread implementation appears to be broken; use Win32 */ +# if !defined(LIBXSMM_WIN32_THREADS) && (defined(_WIN32) || defined(__CYGWIN__)) +# define LIBXSMM_WIN32_THREADS _WIN32_WINNT +# if defined(__CYGWIN__) || defined(__MINGW32__) /* hack: make SRW-locks available */ +# if defined(_WIN32_WINNT) +# undef _WIN32_WINNT +# if !defined(NTDDI_VERSION) +# define NTDDI_VERSION 0x0600 +# endif +# define _WIN32_WINNT ((LIBXSMM_WIN32_THREADS) | 0x0600) +# else +# define _WIN32_WINNT 0x0600 +# endif +# endif +# endif +# if defined(LIBXSMM_WIN32_THREADS) +# define LIBXSMM_TLS_TYPE DWORD +# define LIBXSMM_TLS_CREATE(KEYPTR) *(KEYPTR) = TlsAlloc() +# define LIBXSMM_TLS_DESTROY(KEY) TlsFree(KEY) +# define LIBXSMM_TLS_SETVALUE(KEY, PTR) TlsSetValue(KEY, PTR) +# define LIBXSMM_TLS_GETVALUE(KEY) TlsGetValue(KEY) +# define LIBXSMM_LOCK_SPINLOCK spin +# if ((LIBXSMM_WIN32_THREADS) & 0x0600) +# define LIBXSMM_LOCK_MUTEX rwlock +# define LIBXSMM_LOCK_RWLOCK rwlock +# else /* mutex exposes high latency */ +# define LIBXSMM_LOCK_MUTEX mutex +# define LIBXSMM_LOCK_RWLOCK mutex +# endif +# if defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) +# define LIBXSMM_LOCK_ACQUIRED_spin TRUE +# define LIBXSMM_LOCK_TYPE_ISPOD_spin 0 +# define LIBXSMM_LOCK_TYPE_spin CRITICAL_SECTION +# define LIBXSMM_LOCK_INIT_spin(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); InitializeCriticalSection(LOCK); } +# define LIBXSMM_LOCK_DESTROY_spin(LOCK) DeleteCriticalSection((LIBXSMM_LOCK_TYPE_spin*)(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_spin(LOCK) TryEnterCriticalSection(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_spin(LOCK) EnterCriticalSection(LOCK) +# define LIBXSMM_LOCK_RELEASE_spin(LOCK) LeaveCriticalSection(LOCK) +# define LIBXSMM_LOCK_TRYREAD_spin(LOCK) LIBXSMM_LOCK_TRYLOCK_spin(LOCK) +# define LIBXSMM_LOCK_ACQREAD_spin(LOCK) LIBXSMM_LOCK_ACQUIRE_spin(LOCK) +# define LIBXSMM_LOCK_RELREAD_spin(LOCK) LIBXSMM_LOCK_RELEASE_spin(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_spin int +# define LIBXSMM_LOCK_ATTR_INIT_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if defined(LIBXSMM_LOCK_SYSTEM_MUTEX) +# define LIBXSMM_LOCK_ACQUIRED_mutex WAIT_OBJECT_0 +# define LIBXSMM_LOCK_TYPE_ISPOD_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISRW_mutex 0 +# define LIBXSMM_LOCK_TYPE_mutex HANDLE +# define LIBXSMM_LOCK_INIT_mutex(LOCK, ATTR) (*(LOCK) = CreateMutex(*(ATTR), FALSE, NULL)) +# define LIBXSMM_LOCK_DESTROY_mutex(LOCK) CloseHandle(*(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) WaitForSingleObject(*(LOCK), 0) +# define LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) WaitForSingleObject(*(LOCK), INFINITE) +# define LIBXSMM_LOCK_RELEASE_mutex(LOCK) ReleaseMutex(*(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_mutex(LOCK) LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) +# define LIBXSMM_LOCK_ACQREAD_mutex(LOCK) LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) +# define LIBXSMM_LOCK_RELREAD_mutex(LOCK) LIBXSMM_LOCK_RELEASE_mutex(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_mutex LPSECURITY_ATTRIBUTES +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) (*(ATTR) = NULL) +# define LIBXSMM_LOCK_ATTR_DESTROY_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if defined(LIBXSMM_LOCK_SYSTEM_RWLOCK) +# define LIBXSMM_LOCK_ACQUIRED_rwlock TRUE +# define LIBXSMM_LOCK_TYPE_ISPOD_rwlock 1 +# define LIBXSMM_LOCK_TYPE_ISRW_rwlock 1 +# define LIBXSMM_LOCK_TYPE_rwlock SRWLOCK +# define LIBXSMM_LOCK_INIT_rwlock(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); InitializeSRWLock(LOCK); } +# define LIBXSMM_LOCK_DESTROY_rwlock(LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) TryAcquireSRWLockExclusive(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) AcquireSRWLockExclusive(LOCK) +# define LIBXSMM_LOCK_RELEASE_rwlock(LOCK) ReleaseSRWLockExclusive(LOCK) +# define LIBXSMM_LOCK_TRYREAD_rwlock(LOCK) TryAcquireSRWLockShared(LOCK) +# define LIBXSMM_LOCK_ACQREAD_rwlock(LOCK) AcquireSRWLockShared(LOCK) +# define LIBXSMM_LOCK_RELREAD_rwlock(LOCK) ReleaseSRWLockShared(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_rwlock int +# define LIBXSMM_LOCK_ATTR_INIT_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# define LIBXSMM_SYNC_YIELD YieldProcessor() +# else +# define LIBXSMM_TLS_TYPE pthread_key_t +# define LIBXSMM_TLS_CREATE(KEYPTR) pthread_key_create(KEYPTR, NULL) +# define LIBXSMM_TLS_DESTROY(KEY) pthread_key_delete(KEY) +# define LIBXSMM_TLS_SETVALUE(KEY, PTR) pthread_setspecific(KEY, PTR) +# define LIBXSMM_TLS_GETVALUE(KEY) pthread_getspecific(KEY) +# if defined(__APPLE__) && defined(__MACH__) +# define LIBXSMM_SYNC_YIELD pthread_yield_np() +# else +# if defined(__USE_GNU) || !defined(__BSD_VISIBLE) + LIBXSMM_EXTERN int pthread_yield(void) LIBXSMM_THROW; +# else + LIBXSMM_EXTERN void pthread_yield(void); +# endif +# define LIBXSMM_SYNC_YIELD pthread_yield() +# endif +# if defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) && defined(__APPLE__) && defined(__MACH__) +# define LIBXSMM_LOCK_SPINLOCK mutex +# else +# define LIBXSMM_LOCK_SPINLOCK spin +# endif +# define LIBXSMM_LOCK_MUTEX mutex +# define LIBXSMM_LOCK_RWLOCK rwlock +# if defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) +# define LIBXSMM_LOCK_ACQUIRED_spin 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_spin 0 +# define LIBXSMM_LOCK_TYPE_ISRW_spin 0 +# define LIBXSMM_LOCK_TYPE_spin pthread_spinlock_t +# define LIBXSMM_LOCK_INIT_spin(LOCK, ATTR) LIBXSMM_EXPECT(0, pthread_spin_init(LOCK, *(ATTR))) +# define LIBXSMM_LOCK_DESTROY_spin(LOCK) LIBXSMM_EXPECT(0, pthread_spin_destroy(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_spin(LOCK) pthread_spin_trylock(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_spin(LOCK) LIBXSMM_EXPECT(0, pthread_spin_lock(LOCK)) +# define LIBXSMM_LOCK_RELEASE_spin(LOCK) LIBXSMM_EXPECT(0, pthread_spin_unlock(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_spin(LOCK) LIBXSMM_LOCK_TRYLOCK_spin(LOCK) +# define LIBXSMM_LOCK_ACQREAD_spin(LOCK) LIBXSMM_LOCK_ACQUIRE_spin(LOCK) +# define LIBXSMM_LOCK_RELREAD_spin(LOCK) LIBXSMM_LOCK_RELEASE_spin(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_spin int +# define LIBXSMM_LOCK_ATTR_INIT_spin(ATTR) (*(ATTR) = 0) +# define LIBXSMM_LOCK_ATTR_DESTROY_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if defined(LIBXSMM_LOCK_SYSTEM_MUTEX) +# define LIBXSMM_LOCK_ACQUIRED_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISRW_mutex 0 +# define LIBXSMM_LOCK_TYPE_mutex pthread_mutex_t +# define LIBXSMM_LOCK_INIT_mutex(LOCK, ATTR) LIBXSMM_EXPECT(0, pthread_mutex_init(LOCK, ATTR)) +# define LIBXSMM_LOCK_DESTROY_mutex(LOCK) LIBXSMM_EXPECT_DEBUG(0, pthread_mutex_destroy(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) pthread_mutex_trylock(LOCK) /*!LIBXSMM_EXPECT*/ +# define LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) LIBXSMM_EXPECT(0, pthread_mutex_lock(LOCK)) +# define LIBXSMM_LOCK_RELEASE_mutex(LOCK) LIBXSMM_EXPECT(0, pthread_mutex_unlock(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_mutex(LOCK) LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) +# define LIBXSMM_LOCK_ACQREAD_mutex(LOCK) LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) +# define LIBXSMM_LOCK_RELREAD_mutex(LOCK) LIBXSMM_LOCK_RELEASE_mutex(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_mutex pthread_mutexattr_t +#if !defined(__linux__) || defined(__USE_UNIX98) || defined(__USE_XOPEN2K8) +# if defined(_DEBUG) +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) (LIBXSMM_EXPECT(0, pthread_mutexattr_init(ATTR)), \ + LIBXSMM_EXPECT(0, pthread_mutexattr_settype(ATTR, PTHREAD_MUTEX_ERRORCHECK))) +# else +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) (pthread_mutexattr_init(ATTR), \ + pthread_mutexattr_settype(ATTR, PTHREAD_MUTEX_NORMAL)) +# endif +#else +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) pthread_mutexattr_init(ATTR) +#endif +# define LIBXSMM_LOCK_ATTR_DESTROY_mutex(ATTR) LIBXSMM_EXPECT(0, pthread_mutexattr_destroy(ATTR)) +# endif +# if defined(LIBXSMM_LOCK_SYSTEM_RWLOCK) +# define LIBXSMM_LOCK_ACQUIRED_rwlock 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_rwlock 0 +# define LIBXSMM_LOCK_TYPE_ISRW_rwlock 1 +# define LIBXSMM_LOCK_TYPE_rwlock pthread_rwlock_t +# define LIBXSMM_LOCK_INIT_rwlock(LOCK, ATTR) LIBXSMM_EXPECT(0, pthread_rwlock_init(LOCK, ATTR)) +# define LIBXSMM_LOCK_DESTROY_rwlock(LOCK) LIBXSMM_EXPECT(0, pthread_rwlock_destroy(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) pthread_rwlock_trywrlock(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) LIBXSMM_EXPECT(0, pthread_rwlock_wrlock(LOCK)) +# define LIBXSMM_LOCK_RELEASE_rwlock(LOCK) LIBXSMM_EXPECT(0, pthread_rwlock_unlock(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_rwlock(LOCK) pthread_rwlock_tryrdlock(LOCK) +# define LIBXSMM_LOCK_ACQREAD_rwlock(LOCK) LIBXSMM_EXPECT(0, pthread_rwlock_rdlock(LOCK)) +# define LIBXSMM_LOCK_RELREAD_rwlock(LOCK) LIBXSMM_LOCK_RELEASE_rwlock(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_rwlock pthread_rwlockattr_t +# define LIBXSMM_LOCK_ATTR_INIT_rwlock(ATTR) LIBXSMM_EXPECT(0, pthread_rwlockattr_init(ATTR)) +# define LIBXSMM_LOCK_ATTR_DESTROY_rwlock(ATTR) LIBXSMM_EXPECT(0, pthread_rwlockattr_destroy(ATTR)) +# endif +# endif +/* OpenMP based locks need to stay disabled unless both + * libxsmm and libxsmmext are built with OpenMP support. + */ +# if defined(_OPENMP) && defined(LIBXSMM_SYNC_OMP) +# if !defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) +# define LIBXSMM_LOCK_ACQUIRED_spin 1 +# define LIBXSMM_LOCK_TYPE_ISPOD_spin 0 +# define LIBXSMM_LOCK_TYPE_ISRW_spin 0 +# define LIBXSMM_LOCK_TYPE_spin omp_lock_t +# define LIBXSMM_LOCK_DESTROY_spin(LOCK) omp_destroy_lock(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_spin(LOCK) omp_test_lock(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_spin(LOCK) omp_set_lock(LOCK) +# define LIBXSMM_LOCK_RELEASE_spin(LOCK) omp_unset_lock(LOCK) +# define LIBXSMM_LOCK_TRYREAD_spin(LOCK) LIBXSMM_LOCK_TRYLOCK_spin(LOCK) +# define LIBXSMM_LOCK_ACQREAD_spin(LOCK) LIBXSMM_LOCK_ACQUIRE_spin(LOCK) +# define LIBXSMM_LOCK_RELREAD_spin(LOCK) LIBXSMM_LOCK_RELEASE_spin(LOCK) +# if (201811 <= _OPENMP/*v5.0*/) +# define LIBXSMM_LOCK_INIT_spin(LOCK, ATTR) omp_init_lock_with_hint(LOCK, *(ATTR)) +# define LIBXSMM_LOCK_ATTR_TYPE_spin omp_lock_hint_t +# define LIBXSMM_LOCK_ATTR_INIT_spin(ATTR) (*(ATTR) = omp_lock_hint_none) +# else +# define LIBXSMM_LOCK_INIT_spin(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); omp_init_lock(LOCK); } +# define LIBXSMM_LOCK_ATTR_TYPE_spin const void* +# define LIBXSMM_LOCK_ATTR_INIT_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# define LIBXSMM_LOCK_ATTR_DESTROY_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_MUTEX) +# define LIBXSMM_LOCK_ACQUIRED_mutex 1 +# define LIBXSMM_LOCK_TYPE_ISPOD_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISRW_mutex 0 +# define LIBXSMM_LOCK_TYPE_mutex omp_lock_t +# define LIBXSMM_LOCK_DESTROY_mutex(LOCK) omp_destroy_lock(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) omp_test_lock(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) omp_set_lock(LOCK) +# define LIBXSMM_LOCK_RELEASE_mutex(LOCK) omp_unset_lock(LOCK) +# define LIBXSMM_LOCK_TRYREAD_mutex(LOCK) LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) +# define LIBXSMM_LOCK_ACQREAD_mutex(LOCK) LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) +# define LIBXSMM_LOCK_RELREAD_mutex(LOCK) LIBXSMM_LOCK_RELEASE_mutex(LOCK) +# if (201811 <= _OPENMP/*v5.0*/) +# define LIBXSMM_LOCK_INIT_mutex(LOCK, ATTR) omp_init_lock_with_hint(LOCK, *(ATTR)) +# define LIBXSMM_LOCK_ATTR_TYPE_mutex omp_lock_hint_t +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) (*(ATTR) = omp_lock_hint_none) +# else +# define LIBXSMM_LOCK_INIT_mutex(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); omp_init_lock(LOCK); } +# define LIBXSMM_LOCK_ATTR_TYPE_mutex const void* +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# define LIBXSMM_LOCK_ATTR_DESTROY_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_RWLOCK) +# define LIBXSMM_LOCK_ACQUIRED_rwlock 1 +# define LIBXSMM_LOCK_TYPE_ISPOD_rwlock 0 +# define LIBXSMM_LOCK_TYPE_ISRW_rwlock 0 +# define LIBXSMM_LOCK_TYPE_rwlock omp_lock_t +# define LIBXSMM_LOCK_DESTROY_rwlock(LOCK) omp_destroy_lock(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) omp_test_lock(LOCK) +# define LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) omp_set_lock(LOCK) +# define LIBXSMM_LOCK_RELEASE_rwlock(LOCK) omp_unset_lock(LOCK) +# define LIBXSMM_LOCK_TRYREAD_rwlock(LOCK) LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) +# define LIBXSMM_LOCK_ACQREAD_rwlock(LOCK) LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) +# define LIBXSMM_LOCK_RELREAD_rwlock(LOCK) LIBXSMM_LOCK_RELEASE_rwlock(LOCK) +# if (201811 <= _OPENMP/*v5.0*/) +# define LIBXSMM_LOCK_INIT_rwlock(LOCK, ATTR) omp_init_lock_with_hint(LOCK, *(ATTR)) +# define LIBXSMM_LOCK_ATTR_TYPE_rwlock omp_lock_hint_t +# define LIBXSMM_LOCK_ATTR_INIT_rwlock(ATTR) (*(ATTR) = omp_lock_hint_none) +# else +# define LIBXSMM_LOCK_INIT_rwlock(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); omp_init_lock(LOCK); } +# define LIBXSMM_LOCK_ATTR_TYPE_rwlock const void* +# define LIBXSMM_LOCK_ATTR_INIT_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# define LIBXSMM_LOCK_ATTR_DESTROY_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# elif !defined(LIBXSMM_SYNC_NONE) /* based on atomic primitives */ +# if !defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) +# define LIBXSMM_LOCK_ACQUIRED_spin 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_spin 1 +# define LIBXSMM_LOCK_TYPE_ISRW_spin 0 +# define LIBXSMM_LOCK_TYPE_spin volatile LIBXSMM_ATOMIC_LOCKTYPE +# define LIBXSMM_LOCK_INIT_spin(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); (*(LOCK) = 0); } +# define LIBXSMM_LOCK_DESTROY_spin(LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_spin(LOCK) (LIBXSMM_LOCK_ACQUIRED_spin + !LIBXSMM_ATOMIC_TRYLOCK(LOCK, LIBXSMM_ATOMIC_RELAXED)) +# define LIBXSMM_LOCK_ACQUIRE_spin(LOCK) LIBXSMM_ATOMIC_ACQUIRE(LOCK, LIBXSMM_SYNC_NPAUSE, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_LOCK_RELEASE_spin(LOCK) LIBXSMM_ATOMIC_RELEASE(LOCK, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_LOCK_TRYREAD_spin(LOCK) LIBXSMM_LOCK_TRYLOCK_spin(LOCK) +# define LIBXSMM_LOCK_ACQREAD_spin(LOCK) LIBXSMM_LOCK_ACQUIRE_spin(LOCK) +# define LIBXSMM_LOCK_RELREAD_spin(LOCK) LIBXSMM_LOCK_RELEASE_spin(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_spin int +# define LIBXSMM_LOCK_ATTR_INIT_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_MUTEX) +# define LIBXSMM_LOCK_ACQUIRED_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_mutex 1 +# define LIBXSMM_LOCK_TYPE_ISRW_mutex 0 +# define LIBXSMM_LOCK_TYPE_mutex volatile LIBXSMM_ATOMIC_LOCKTYPE +# define LIBXSMM_LOCK_INIT_mutex(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); (*(LOCK) = 0); } +# define LIBXSMM_LOCK_DESTROY_mutex(LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) (LIBXSMM_LOCK_ACQUIRED_mutex + !LIBXSMM_ATOMIC_TRYLOCK(LOCK, LIBXSMM_ATOMIC_RELAXED)) +# define LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) LIBXSMM_ATOMIC_ACQUIRE(LOCK, LIBXSMM_SYNC_NPAUSE, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_LOCK_RELEASE_mutex(LOCK) LIBXSMM_ATOMIC_RELEASE(LOCK, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_LOCK_TRYREAD_mutex(LOCK) LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) +# define LIBXSMM_LOCK_ACQREAD_mutex(LOCK) LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) +# define LIBXSMM_LOCK_RELREAD_mutex(LOCK) LIBXSMM_LOCK_RELEASE_mutex(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_mutex int +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_RWLOCK) +# define LIBXSMM_LOCK_ACQUIRED_rwlock 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_rwlock 1 +# define LIBXSMM_LOCK_TYPE_ISRW_rwlock 0 +# define LIBXSMM_LOCK_TYPE_rwlock volatile LIBXSMM_ATOMIC_LOCKTYPE +# define LIBXSMM_LOCK_INIT_rwlock(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); (*(LOCK) = 0); } +# define LIBXSMM_LOCK_DESTROY_rwlock(LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) (LIBXSMM_LOCK_ACQUIRED_rwlock + !LIBXSMM_ATOMIC_TRYLOCK(LOCK, LIBXSMM_ATOMIC_RELAXED)) +# define LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) LIBXSMM_ATOMIC_ACQUIRE(LOCK, LIBXSMM_SYNC_NPAUSE, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_LOCK_RELEASE_rwlock(LOCK) LIBXSMM_ATOMIC_RELEASE(LOCK, LIBXSMM_ATOMIC_RELAXED) +# define LIBXSMM_LOCK_TRYREAD_rwlock(LOCK) LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) +# define LIBXSMM_LOCK_ACQREAD_rwlock(LOCK) LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) +# define LIBXSMM_LOCK_RELREAD_rwlock(LOCK) LIBXSMM_LOCK_RELEASE_rwlock(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_rwlock int +# define LIBXSMM_LOCK_ATTR_INIT_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# else /* experimental */ +# if !defined(LIBXSMM_LOCK_SYSTEM_SPINLOCK) +# define LIBXSMM_LOCK_ACQUIRED_spin 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_spin 0 +# define LIBXSMM_LOCK_TYPE_ISRW_spin 0 +# define LIBXSMM_LOCK_TYPE_spin libxsmm_spinlock* +# define LIBXSMM_LOCK_INIT_spin(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); (*(LOCK) = libxsmm_spinlock_create()); } +# define LIBXSMM_LOCK_DESTROY_spin(LOCK) libxsmm_spinlock_destroy(*(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_spin(LOCK) libxsmm_spinlock_trylock(*(LOCK)) +# define LIBXSMM_LOCK_ACQUIRE_spin(LOCK) libxsmm_spinlock_acquire(*(LOCK)) +# define LIBXSMM_LOCK_RELEASE_spin(LOCK) libxsmm_spinlock_release(*(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_spin(LOCK) LIBXSMM_LOCK_TRYLOCK_spin(LOCK) +# define LIBXSMM_LOCK_ACQREAD_spin(LOCK) LIBXSMM_LOCK_ACQUIRE_spin(LOCK) +# define LIBXSMM_LOCK_RELREAD_spin(LOCK) LIBXSMM_LOCK_RELEASE_spin(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_spin int +# define LIBXSMM_LOCK_ATTR_INIT_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_spin(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_MUTEX) +# define LIBXSMM_LOCK_ACQUIRED_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_mutex 0 +# define LIBXSMM_LOCK_TYPE_ISRW_mutex 0 +# define LIBXSMM_LOCK_TYPE_mutex libxsmm_mutex* +# define LIBXSMM_LOCK_INIT_mutex(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); (*(LOCK) = libxsmm_mutex_create()); } +# define LIBXSMM_LOCK_DESTROY_mutex(LOCK) libxsmm_mutex_destroy(*(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) libxsmm_mutex_trylock(*(LOCK)) +# define LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) libxsmm_mutex_acquire(*(LOCK)) +# define LIBXSMM_LOCK_RELEASE_mutex(LOCK) libxsmm_mutex_release(*(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_mutex(LOCK) LIBXSMM_LOCK_TRYLOCK_mutex(LOCK) +# define LIBXSMM_LOCK_ACQREAD_mutex(LOCK) LIBXSMM_LOCK_ACQUIRE_mutex(LOCK) +# define LIBXSMM_LOCK_RELREAD_mutex(LOCK) LIBXSMM_LOCK_RELEASE_mutex(LOCK) +# define LIBXSMM_LOCK_ATTR_TYPE_mutex int +# define LIBXSMM_LOCK_ATTR_INIT_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_mutex(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# if !defined(LIBXSMM_LOCK_SYSTEM_RWLOCK) +# define LIBXSMM_LOCK_ACQUIRED_rwlock 0 +# define LIBXSMM_LOCK_TYPE_ISPOD_rwlock 0 +# define LIBXSMM_LOCK_TYPE_ISRW_rwlock 1 +# define LIBXSMM_LOCK_TYPE_rwlock libxsmm_rwlock* +# define LIBXSMM_LOCK_INIT_rwlock(LOCK, ATTR) { LIBXSMM_UNUSED(ATTR); (*(LOCK) = libxsmm_rwlock_create()); } +# define LIBXSMM_LOCK_DESTROY_rwlock(LOCK) libxsmm_rwlock_destroy(*(LOCK)) +# define LIBXSMM_LOCK_TRYLOCK_rwlock(LOCK) libxsmm_rwlock_trylock(*(LOCK)) +# define LIBXSMM_LOCK_ACQUIRE_rwlock(LOCK) libxsmm_rwlock_acquire(*(LOCK)) +# define LIBXSMM_LOCK_RELEASE_rwlock(LOCK) libxsmm_rwlock_release(*(LOCK)) +# define LIBXSMM_LOCK_TRYREAD_rwlock(LOCK) libxsmm_rwlock_tryread(*(LOCK)) +# define LIBXSMM_LOCK_ACQREAD_rwlock(LOCK) libxsmm_rwlock_acqread(*(LOCK)) +# define LIBXSMM_LOCK_RELREAD_rwlock(LOCK) libxsmm_rwlock_relread(*(LOCK)) +# define LIBXSMM_LOCK_ATTR_TYPE_rwlock int +# define LIBXSMM_LOCK_ATTR_INIT_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY_rwlock(ATTR) LIBXSMM_UNUSED(ATTR) +# endif +# endif +#else /* no synchronization */ +# define LIBXSMM_SYNC_YIELD LIBXSMM_SYNC_PAUSE +# define LIBXSMM_LOCK_SPINLOCK spinlock_dummy +# define LIBXSMM_LOCK_MUTEX mutex_dummy +# define LIBXSMM_LOCK_RWLOCK rwlock_dummy +# define LIBXSMM_LOCK_ACQUIRED(KIND) 0 +# define LIBXSMM_LOCK_TYPE_ISPOD(KIND) 1 +# define LIBXSMM_LOCK_TYPE_ISRW(KIND) 0 +# define LIBXSMM_LOCK_ATTR_TYPE(KIND) int +# define LIBXSMM_LOCK_ATTR_INIT(KIND, ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_ATTR_DESTROY(KIND, ATTR) LIBXSMM_UNUSED(ATTR) +# define LIBXSMM_LOCK_TYPE(KIND) int +# define LIBXSMM_LOCK_INIT(KIND, LOCK, ATTR) { LIBXSMM_UNUSED(LOCK); LIBXSMM_UNUSED(ATTR); } +# define LIBXSMM_LOCK_DESTROY(KIND, LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_TRYLOCK(KIND, LOCK) LIBXSMM_LOCK_ACQUIRED(KIND) +# define LIBXSMM_LOCK_ACQUIRE(KIND, LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_RELEASE(KIND, LOCK) LIBXSMM_UNUSED(LOCK) +# define LIBXSMM_LOCK_TRYREAD(KIND, LOCK) LIBXSMM_LOCK_TRYLOCK(KIND, LOCK) +# define LIBXSMM_LOCK_ACQREAD(KIND, LOCK) LIBXSMM_LOCK_ACQUIRE(KIND, LOCK) +# define LIBXSMM_LOCK_RELREAD(KIND, LOCK) LIBXSMM_LOCK_RELEASE(KIND, LOCK) +#endif + +#if (0 == LIBXSMM_SYNC) +# define LIBXSMM_FLOCK(FILE) +# define LIBXSMM_FUNLOCK(FILE) +#elif defined(_WIN32) +# define LIBXSMM_FLOCK(FILE) _lock_file(FILE) +# define LIBXSMM_FUNLOCK(FILE) _unlock_file(FILE) +#else +# if !defined(__CYGWIN__) +# define LIBXSMM_FLOCK(FILE) flockfile(FILE) +# define LIBXSMM_FUNLOCK(FILE) funlockfile(FILE) + LIBXSMM_EXTERN void flockfile(FILE*) LIBXSMM_THROW; + LIBXSMM_EXTERN void funlockfile(FILE*) LIBXSMM_THROW; +# else /* Only available with __CYGWIN__ *and* C++0x. */ +# define LIBXSMM_FLOCK(FILE) +# define LIBXSMM_FUNLOCK(FILE) +# endif +#endif + +/** Synchronize console output */ +#define LIBXSMM_STDIO_ACQUIRE() LIBXSMM_FLOCK(stdout); LIBXSMM_FLOCK(stderr) +#define LIBXSMM_STDIO_RELEASE() LIBXSMM_FUNLOCK(stderr); LIBXSMM_FUNLOCK(stdout) + + +/** Opaque type which represents a barrier. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_barrier libxsmm_barrier; + +/** Create barrier from one of the threads. */ +LIBXSMM_API libxsmm_barrier* libxsmm_barrier_create(int ncores, int nthreads_per_core); +/** Initialize the barrier from each thread of the team. */ +LIBXSMM_API void libxsmm_barrier_init(libxsmm_barrier* barrier, int tid); +/** Wait for the entire team to arrive. */ +LIBXSMM_API void libxsmm_barrier_wait(libxsmm_barrier* barrier, int tid); +/** Destroy the resources associated with this barrier. */ +LIBXSMM_API void libxsmm_barrier_destroy(const libxsmm_barrier* barrier); +/** DEPRECATED: use libxsmm_barrier_destroy instead. */ +#define libxsmm_barrier_release libxsmm_barrier_destroy + +/** Spin-lock, which eventually differs from LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_SPINLOCK). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_spinlock libxsmm_spinlock; +LIBXSMM_API libxsmm_spinlock* libxsmm_spinlock_create(void); +LIBXSMM_API void libxsmm_spinlock_destroy(const libxsmm_spinlock* spinlock); +LIBXSMM_API int libxsmm_spinlock_trylock(libxsmm_spinlock* spinlock); +LIBXSMM_API void libxsmm_spinlock_acquire(libxsmm_spinlock* spinlock); +LIBXSMM_API void libxsmm_spinlock_release(libxsmm_spinlock* spinlock); + +/** Mutual-exclusive lock (Mutex), which eventually differs from LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_MUTEX). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_mutex libxsmm_mutex; +LIBXSMM_API libxsmm_mutex* libxsmm_mutex_create(void); +LIBXSMM_API void libxsmm_mutex_destroy(const libxsmm_mutex* mutex); +LIBXSMM_API int libxsmm_mutex_trylock(libxsmm_mutex* mutex); +LIBXSMM_API void libxsmm_mutex_acquire(libxsmm_mutex* mutex); +LIBXSMM_API void libxsmm_mutex_release(libxsmm_mutex* mutex); + +/** Reader-Writer lock (RW-lock), which eventually differs from LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_rwlock libxsmm_rwlock; +LIBXSMM_API libxsmm_rwlock* libxsmm_rwlock_create(void); +LIBXSMM_API void libxsmm_rwlock_destroy(const libxsmm_rwlock* rwlock); +LIBXSMM_API int libxsmm_rwlock_trylock(libxsmm_rwlock* rwlock); +LIBXSMM_API void libxsmm_rwlock_acquire(libxsmm_rwlock* rwlock); +LIBXSMM_API void libxsmm_rwlock_release(libxsmm_rwlock* rwlock); +LIBXSMM_API int libxsmm_rwlock_tryread(libxsmm_rwlock* rwlock); +LIBXSMM_API void libxsmm_rwlock_acqread(libxsmm_rwlock* rwlock); +LIBXSMM_API void libxsmm_rwlock_relread(libxsmm_rwlock* rwlock); + +/** Utility function to receive the process ID of the calling process. */ +LIBXSMM_API unsigned int libxsmm_get_pid(void); +/** + * Utility function to receive a Thread-ID (TID) for the calling thread. + * The TID is not related to a specific threading runtime. TID=0 may not + * represent the main thread. TIDs are zero-based and consecutive numbers. + */ +LIBXSMM_API unsigned int libxsmm_get_tid(void); + +#endif /*LIBXSMM_SYNC_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_timer.h b/third_party/libxsmm/include/libxsmm_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..dcd9cb04ff88264304d66755cfc8de73e8b02231 --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_timer.h @@ -0,0 +1,41 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_TIMER_H +#define LIBXSMM_TIMER_H + +#include "libxsmm_macros.h" + + +typedef unsigned long long libxsmm_timer_tickint; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_timer_info { + int tsc; +} libxsmm_timer_info; + + +/** Query timer properties. */ +LIBXSMM_API int libxsmm_get_timer_info(libxsmm_timer_info* info); + +/** + * Returns the current clock tick of a monotonic timer source with + * platform-specific resolution (not necessarily CPU cycles). + */ +LIBXSMM_API libxsmm_timer_tickint libxsmm_timer_tick(void); + +/** Returns the difference between two timer ticks (cycles); avoids potential side-effects/assumptions of LIBXSMM_DIFF. */ +LIBXSMM_API_INLINE libxsmm_timer_tickint libxsmm_timer_ncycles(libxsmm_timer_tickint tick0, libxsmm_timer_tickint tick1) { + return LIBXSMM_DELTA(tick0, tick1); +} + +/** Returns the duration (in seconds) between two values received by libxsmm_timer_tick. */ +LIBXSMM_API double libxsmm_timer_duration(libxsmm_timer_tickint tick0, libxsmm_timer_tickint tick1); + +#endif /*LIBXSMM_TIMER_H*/ diff --git a/third_party/libxsmm/include/libxsmm_typedefs.h b/third_party/libxsmm/include/libxsmm_typedefs.h new file mode 100644 index 0000000000000000000000000000000000000000..dc2405c9578c2a3d465d1e50dcf79e299d269e4e --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_typedefs.h @@ -0,0 +1,878 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_TYPEDEFS_H +#define LIBXSMM_TYPEDEFS_H + +#include "libxsmm_macros.h" + +/** Check ILP64 configuration for sanity. */ +#if !defined(LIBXSMM_ILP64) || (0 == LIBXSMM_ILP64 && defined(MKL_ILP64)) +# error "Inconsistent ILP64 configuration detected!" +#elif (0 != LIBXSMM_ILP64 && !defined(MKL_ILP64)) +# define MKL_ILP64 +#endif +#if (0 != LIBXSMM_ILP64) +# define LIBXSMM_BLASINT_NBITS 64 +# define LIBXSMM_BLASINT long long +#else /* LP64 */ +# define LIBXSMM_BLASINT_NBITS 32 +# define LIBXSMM_BLASINT int +#endif + +/** Generic prefetches; similar to LIBXSMM_PREFETCH_AUTO (libxsmm_frontend.h) */ +#define LIBXSMM_PREFETCH_SIGONLY 1 +#define LIBXSMM_PREFETCH_NONE 0 + +/** Helper macro for type names. */ +#define LIBXSMM_TYPENAME(TYPE) LIBXSMM_STRINGIFY(LIBXSMM_CONCATENATE(LIBXSMM_TYPENAME_, TYPE)) +#define LIBXSMM_TYPENAME_double f64 +#define LIBXSMM_TYPENAME_float f32 +#define LIBXSMM_TYPENAME_libxsmm_bfloat16 bf16 +#define LIBXSMM_TYPENAME_libxsmm_float16 f16 +#define LIBXSMM_TYPENAME_int i32 +#define LIBXSMM_TYPENAME_short i16 +#define LIBXSMM_TYPENAME_char i8 + +/** Helper macro for type information: INFO := { FP }. */ +#define LIBXSMM_TYPEINFO(TYPE, INFO) LIBXSMM_CONCATENATE4(LIBXSMM_TYPEINFO_, INFO, _, TYPE) +#define LIBXSMM_TYPEINFO_FP_double 1 +#define LIBXSMM_TYPEINFO_FP_float 1 +#define LIBXSMM_TYPEINFO_FP_libxsmm_bfloat16 1 +#define LIBXSMM_TYPEINFO_FP_libxsmm_float16 1 +#define LIBXSMM_TYPEINFO_FP_int 0 +#define LIBXSMM_TYPEINFO_FP_short 0 +#define LIBXSMM_TYPEINFO_FP_char 0 + +/** Helper macro for type postfixes. */ +#define LIBXSMM_TYPESYMBOL(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_TYPESYMBOL_, TYPE) +#define LIBXSMM_TYPESYMBOL_double F64 +#define LIBXSMM_TYPESYMBOL_float F32 +#define LIBXSMM_TYPESYMBOL_libxsmm_bfloat16 BF16 +#define LIBXSMM_TYPESYMBOL_libxsmm_float16 F16 +#define LIBXSMM_TYPESYMBOL_int I32 +#define LIBXSMM_TYPESYMBOL_short I16 +#define LIBXSMM_TYPESYMBOL_char I8 + +#define LIBXSMM_TYPESIZE(ENUM) ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_F64 ? 8 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_F32 ? 4 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_BF16 ? 2 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_F16 ? 2 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_I64 ? 8 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_I32 ? 4 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_I16 ? 2 : ( \ + ((int)(ENUM)) == LIBXSMM_DATATYPE_I8 ? 1 : ( \ + 0/*invalid*/))))))))) + +/* Get input or output precision */ +#define LIBXSMM_GETENUM_INP(SRC) ((SRC) & 0x0F) +#define LIBXSMM_GETENUM_OUT(SRC) (0 == ((SRC) >> 4) ? LIBXSMM_GETENUM_INP(SRC) : ((SRC) >> 4)) +/* Get/Set input and output precision */ +#define LIBXSMM_GETENUM(INP, OUT) (((INP) == (OUT)) ? (INP) : ((INP) | ((OUT) << 4))) +#define LIBXSMM_SETENUM(DST, INP, OUT) DST = LIBXSMM_GETENUM(INP, OUT) + +/* Construct an enumerator (libxsmm_datatype) from a built-in type (float, double, etc.). */ +#define LIBXSMM_DATATYPE(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_DATATYPE_, LIBXSMM_TYPESYMBOL(TYPE)) +/* Construct a type-id from built-in input/output types (float, double, etc.). */ +#define LIBXSMM_DATATYPE2(ITYPE, OTYPE) LIBXSMM_GETENUM(LIBXSMM_DATATYPE(ITYPE), LIBXSMM_DATATYPE(OTYPE)) + +/* Construct an enumerator (libxsmm_gemm_precision) from a built-in type (float, double, etc.). */ +#define LIBXSMM_GEMM_PRECISION(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_GEMM_PRECISION_, LIBXSMM_TYPESYMBOL(TYPE)) +/* Construct GEMM-precision from built-in input/output types (float, double, etc.). */ +#define LIBXSMM_GEMM_PRECISION2(ITYPE, OTYPE) (libxsmm_gemm_precision)LIBXSMM_GETENUM( \ + LIBXSMM_GEMM_PRECISION(ITYPE), LIBXSMM_GEMM_PRECISION(OTYPE)) + +/** Maximum size available to store a descriptor/blob (GEMM, MCOPY, TRANS, TRSM, TRMM). */ +#if !defined(LIBXSMM_DESCRIPTOR_MAXSIZE) +# define LIBXSMM_DESCRIPTOR_MAXSIZE 96 +#endif +/** Size of the descriptor considered as unique/small signature. */ +#if !defined(LIBXSMM_DESCRIPTOR_SIGSIZE) +# if defined(LIBXSMM_UNPACKED) +# define LIBXSMM_DESCRIPTOR_SIGSIZE 64 +# else +# define LIBXSMM_DESCRIPTOR_SIGSIZE 32 +# endif +#endif + + +/* Support for Bfloat16 */ +typedef unsigned short libxsmm_bfloat16; +typedef unsigned short libxsmm_float16; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_bfloat16_hp { + libxsmm_bfloat16 i[2]; + float f; +} libxsmm_bfloat16_hp; + +#if defined(__cplusplus) +namespace Eigen { struct bfloat16; } +#endif /*__cplusplus*/ + +/** Integer type for LAPACK/BLAS (LP64: 32-bit, and ILP64: 64-bit). */ +typedef LIBXSMM_BLASINT libxsmm_blasint; + +/** Type representing sufficient storage space for a GEMM handle. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_gemm_blob { char data[128]; } libxsmm_gemm_blob; +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_gemm_handle libxsmm_gemm_handle; + +/** Type representing sufficient storage space for descriptors (GEMM, TCOPY, MCOPY). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_descriptor_blob { + char data[LIBXSMM_DESCRIPTOR_MAXSIZE]; +} libxsmm_descriptor_blob; + +/** Structure storing arguments of GEMM-like routines. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_gemm_descriptor libxsmm_gemm_descriptor; +/** Structure storing arguments of the matrix-eltw routine. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_descriptor libxsmm_meltw_descriptor; +/** Structure storing arguments of the matrix-equation routine. */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meqn_descriptor libxsmm_meqn_descriptor; + +/** Enumerates element/data types. */ +typedef enum libxsmm_datatype { + LIBXSMM_DATATYPE_F64, + LIBXSMM_DATATYPE_F32, + LIBXSMM_DATATYPE_BF16, + LIBXSMM_DATATYPE_F16, + LIBXSMM_DATATYPE_I64, + LIBXSMM_DATATYPE_I32, + LIBXSMM_DATATYPE_I16, + LIBXSMM_DATATYPE_I8, + LIBXSMM_DATATYPE_UNSUPPORTED +} libxsmm_datatype; + +/** Denotes the precision/data type of GEMM. */ +typedef enum libxsmm_gemm_precision { + LIBXSMM_GEMM_PRECISION_F64 = LIBXSMM_DATATYPE_F64, + LIBXSMM_GEMM_PRECISION_F32 = LIBXSMM_DATATYPE_F32, + LIBXSMM_GEMM_PRECISION_BF16 = LIBXSMM_DATATYPE_BF16, + LIBXSMM_GEMM_PRECISION_F16 = LIBXSMM_DATATYPE_F16, + LIBXSMM_GEMM_PRECISION_I32 = LIBXSMM_DATATYPE_I32, + LIBXSMM_GEMM_PRECISION_I16 = LIBXSMM_DATATYPE_I16, + LIBXSMM_GEMM_PRECISION_I8 = LIBXSMM_DATATYPE_I8 +} libxsmm_gemm_precision; + +typedef enum libxsmm_meltw_operation { + LIBXSMM_MELTW_OPERATION_NONE = 0, + /* for fusion into AMX GEMM */ + LIBXSMM_MELTW_OPERATION_CVTFP32BF16 = 1, + LIBXSMM_MELTW_OPERATION_CVTFP32BF16_ACT = 2, + LIBXSMM_MELTW_OPERATION_ACT_CVTFP32BF16 = 3, + LIBXSMM_MELTW_OPERATION_COLBIAS_ACT = 4, + LIBXSMM_MELTW_OPERATION_DECOMPRESS_A = 5, + LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A = 6, + LIBXSMM_MELTW_OPERATION_TRANSFORM_B_NORM_TO_NORMT_EXT_BUFFER = 7, + LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_TRANSFORM_B_NORM_TO_NORMT_EXT_BUFFER = 8, + LIBXSMM_MELTW_OPERATION_TRANSFORM_C_NORM_TO_VNNI_EXT_BUFFER = 9, + LIBXSMM_MELTW_OPERATION_ACT_TRANSFORM_C_NORM_TO_VNNI_EXT_BUFFER = 10, + /* standalone TPPs */ + LIBXSMM_MELTW_OPERATION_REDUCE = 11, /* to be removed */ + LIBXSMM_MELTW_OPERATION_REDUCE_COLS_IDX = 12, + LIBXSMM_MELTW_OPERATION_OPREDUCE_VECS_IDX = 13, + LIBXSMM_MELTW_OPERATION_UNARY = 14, + LIBXSMM_MELTW_OPERATION_BINARY = 15, + LIBXSMM_MELTW_OPERATION_TERNARY = 16 +} libxsmm_meltw_operation; + +typedef enum libxsmm_meltw_null_flags { + LIBXSMM_MELTW_FLAG_NONE = 0 +} libxsmm_meltw_null_flags; + +typedef enum libxsmm_meltw_redu_flags { + LIBXSMM_MELTW_FLAG_REDUCE_NONE = 0, + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD = 1, + LIBXSMM_MELTW_FLAG_REDUCE_OP_MAX = 2, + LIBXSMM_MELTW_FLAG_REDUCE_OP_MUL = 4, + LIBXSMM_MELTW_FLAG_REDUCE_ROWS = 8, + LIBXSMM_MELTW_FLAG_REDUCE_COLS = 16, + LIBXSMM_MELTW_FLAG_REDUCE_ELTS = 32, + LIBXSMM_MELTW_FLAG_REDUCE_ELTS_SQUARED = 64, + LIBXSMM_MELTW_FLAG_REDUCE_NCNC_FORMAT = 128, + LIBXSMM_MELTW_FLAG_REDUCE_COLS_IDX_XOR_ACC = 256, + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_ROWS = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_ROWS, + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_COLS = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_COLS, + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_ROWS_ELTS_SQUARED = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_ROWS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS_SQUARED , + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_ROWS_ELTS_ELTS_SQUARED = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_ROWS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS_SQUARED , + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_COLS_ELTS_ELTS_SQUARED = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_COLS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS_SQUARED , + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_ROWS_ELTS = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_ROWS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS, + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_COLS_ELTS = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_COLS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS, + LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_COLS_ELTS_NCNC_FORMAT = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_COLS | LIBXSMM_MELTW_FLAG_REDUCE_ELTS | LIBXSMM_MELTW_FLAG_REDUCE_NCNC_FORMAT +} libxsmm_meltw_redu_flags; + +typedef enum libxsmm_meltw_relu_flags { + LIBXSMM_MELTW_FLAG_RELU_NONE = 0, + LIBXSMM_MELTW_FLAG_RELU_FWD = 1, + LIBXSMM_MELTW_FLAG_RELU_BWD = 2, + LIBXSMM_MELTW_FLAG_RELU_BITMASK = 4, + LIBXSMM_MELTW_FLAG_RELU_FWD_BITMASK = LIBXSMM_MELTW_FLAG_RELU_FWD | LIBXSMM_MELTW_FLAG_RELU_BITMASK, + LIBXSMM_MELTW_FLAG_RELU_BWD_BITMASK = LIBXSMM_MELTW_FLAG_RELU_BWD | LIBXSMM_MELTW_FLAG_RELU_BITMASK +} libxsmm_meltw_relu_flags; + +typedef enum libxsmm_meltw_cvt_flags { + LIBXSMM_MELTW_FLAG_CVT_NONE = 0, + LIBXSMM_MELTW_FLAG_CVT_VNNI_FORMAT = 1 +} libxsmm_meltw_cvt_flags; + +typedef enum libxsmm_meltw_cvta_flags { + LIBXSMM_MELTW_FLAG_CVTA_NONE = 0, + LIBXSMM_MELTW_FLAG_CVTA_FUSE_RELU = 1, + LIBXSMM_MELTW_FLAG_CVTA_FUSE_TANH = 2, + LIBXSMM_MELTW_FLAG_CVTA_FUSE_SIGM = 4 +} libxsmm_meltw_cvta_flags; + +typedef enum libxsmm_meltw_acvt_flags { + LIBXSMM_MELTW_FLAG_ACVT_NONE = 0, + LIBXSMM_MELTW_FLAG_ACVT_FUSE_TANH = 1, + LIBXSMM_MELTW_FLAG_ACVT_FUSE_SIGM = 2 +} libxsmm_meltw_acvt_flags; + +typedef enum libxsmm_meltw_flags { + LIBXSMM_MELTW_FLAG_FUSE_NONE = 0, + LIBXSMM_MELTW_FLAG_COLBIAS = 1, + LIBXSMM_MELTW_FLAG_ACT_RELU = 2, + LIBXSMM_MELTW_FLAG_ACT_TANH = 4, + LIBXSMM_MELTW_FLAG_ACT_SIGM = 8, + LIBXSMM_MELTW_FLAG_ACT_GELU = 16, + LIBXSMM_MELTW_FLAG_OVERWRITE_C = 32, + LIBXSMM_MELTW_FLAG_ACT_RELU_BWD = 64, + LIBXSMM_MELTW_FLAG_COLBIAS_OVERWRITE_C = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_ACT_RELU_OVERWRITE_C = LIBXSMM_MELTW_FLAG_ACT_RELU | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_ACT_TANH_OVERWRITE_C = LIBXSMM_MELTW_FLAG_ACT_TANH | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_ACT_SIGM_OVERWRITE_C = LIBXSMM_MELTW_FLAG_ACT_SIGM | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_ACT_GELU_OVERWRITE_C = LIBXSMM_MELTW_FLAG_ACT_GELU | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_ACT_RELU_BWD_OVERWRITE_C = LIBXSMM_MELTW_FLAG_ACT_RELU_BWD | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_RELU = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_RELU, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_TANH = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_TANH, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_SIGM = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_SIGM, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_GELU = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_GELU, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_RELU_OVERWRITE_C = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_RELU | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_TANH_OVERWRITE_C = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_TANH | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_SIGM_OVERWRITE_C = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_SIGM | LIBXSMM_MELTW_FLAG_OVERWRITE_C, + LIBXSMM_MELTW_FLAG_COLBIAS_ACT_GELU_OVERWRITE_C = LIBXSMM_MELTW_FLAG_COLBIAS | LIBXSMM_MELTW_FLAG_ACT_GELU | LIBXSMM_MELTW_FLAG_OVERWRITE_C +} libxsmm_meltw_flags; + +typedef enum libxsmm_meltw_opreduce_vecs_flags { + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_NONE = 0, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX = 1, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN = 2, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY = 4, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_ADD = 8, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_SUB = 16, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL = 32, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_DIV = 64, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_DOT = 128, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_SCALE_OP_RESULT = 256, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_NONE = 512, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM = 1024, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX = 2048, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN = 4096, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC = 8192, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC = 16384, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX = 32768, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0 = 65536, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1 = 131072, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY_REDOP_SUM = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL_REDOP_SUM = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY_REDOP_MAX = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX, + LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY_REDOP_MIN = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN +} libxsmm_meltw_opreduce_vecs_flags; + +typedef enum libxsmm_meltw_unary_flags { + LIBXSMM_MELTW_FLAG_UNARY_NONE = 0, + LIBXSMM_MELTW_FLAG_UNARY_BITMASK = 1, + LIBXSMM_MELTW_FLAG_UNARY_BCAST_ROW = 2, + LIBXSMM_MELTW_FLAG_UNARY_BCAST_COL = 4, + LIBXSMM_MELTW_FLAG_UNARY_BCAST_SCALAR = 8, + LIBXSMM_MELTW_FLAG_UNARY_REDUCE_COLS = 16, + LIBXSMM_MELTW_FLAG_UNARY_REDUCE_ROWS = 32 +} libxsmm_meltw_unary_flags; + +typedef enum libxsmm_meltw_unary_type { + LIBXSMM_MELTW_TYPE_UNARY_NONE = 0, + LIBXSMM_MELTW_TYPE_UNARY_IDENTITY = 1, /* this is copy */ + LIBXSMM_MELTW_TYPE_UNARY_XOR = 2, /* this is zero */ + LIBXSMM_MELTW_TYPE_UNARY_X2 = 3, + LIBXSMM_MELTW_TYPE_UNARY_SQRT = 4, + LIBXSMM_MELTW_TYPE_UNARY_RELU = 5, + LIBXSMM_MELTW_TYPE_UNARY_RELU_INV = 6, + LIBXSMM_MELTW_TYPE_UNARY_TANH = 7, + LIBXSMM_MELTW_TYPE_UNARY_TANH_INV = 8, + LIBXSMM_MELTW_TYPE_UNARY_SIGMOID = 9, + LIBXSMM_MELTW_TYPE_UNARY_SIGMOID_INV = 10, + LIBXSMM_MELTW_TYPE_UNARY_GELU = 11, + LIBXSMM_MELTW_TYPE_UNARY_GELU_INV = 12, + LIBXSMM_MELTW_TYPE_UNARY_NEGATE = 13, + LIBXSMM_MELTW_TYPE_UNARY_INC = 14, + LIBXSMM_MELTW_TYPE_UNARY_RECIPROCAL = 15, + LIBXSMM_MELTW_TYPE_UNARY_RECIPROCAL_SQRT = 16, + LIBXSMM_MELTW_TYPE_UNARY_EXP = 17, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_ADD = 18, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X2_OP_ADD = 19, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_X2_OP_ADD = 20, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_MAX = 21, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_MUL = 22, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_ADD_NCNC_FORMAT = 23, + LIBXSMM_MELTW_TYPE_UNARY_REDUCE_TO_SCALAR_OP_ADD = 24, + LIBXSMM_MELTW_TYPE_UNARY_DROPOUT = 25, + LIBXSMM_MELTW_TYPE_UNARY_DROPOUT_INV = 26, + LIBXSMM_MELTW_TYPE_UNARY_REPLICATE_COL_VAR = 27, + LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI = 28, + LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT = 29, + LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI_TO_VNNIT = 30, + LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNIT = 31, + LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI_PAD = 32, + LIBXSMM_MELTW_TYPE_UNARY_UNPACK_TO_BLOCKS = 33, + LIBXSMM_MELTW_TYPE_UNARY_LEAKY_RELU = 34, + LIBXSMM_MELTW_TYPE_UNARY_LEAKY_RELU_INV = 35, + LIBXSMM_MELTW_TYPE_UNARY_ELU = 36, + LIBXSMM_MELTW_TYPE_UNARY_ELU_INV = 37, + LIBXSMM_MELTW_TYPE_UNARY_STOCHASTIC_ROUND = 38 +} libxsmm_meltw_unary_type; + +typedef enum libxsmm_meltw_binary_flags { + LIBXSMM_MELTW_FLAG_BINARY_NONE = 0, + LIBXSMM_MELTW_FLAG_BINARY_BCAST_ROW_IN_0 = 1, + LIBXSMM_MELTW_FLAG_BINARY_BCAST_ROW_IN_1 = 2, + LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_0 = 4, + LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_1 = 8, + LIBXSMM_MELTW_FLAG_BINARY_BCAST_SCALAR_IN_0 = 16, + LIBXSMM_MELTW_FLAG_BINARY_BCAST_SCALAR_IN_1 = 32 +} libxsmm_meltw_binary_flags; + +typedef enum libxsmm_meltw_binary_type { + LIBXSMM_MELTW_TYPE_BINARY_NONE = 0, + LIBXSMM_MELTW_TYPE_BINARY_ADD = 1, + LIBXSMM_MELTW_TYPE_BINARY_MUL = 2, + LIBXSMM_MELTW_TYPE_BINARY_SUB = 3, + LIBXSMM_MELTW_TYPE_BINARY_DIV = 4, + LIBXSMM_MELTW_TYPE_BINARY_MULADD = 5, + LIBXSMM_MELTW_TYPE_BINARY_MATMUL = 6, + LIBXSMM_MELTW_TYPE_BINARY_MUL_AND_REDUCE_TO_SCALAR_OP_ADD = 7, + LIBXSMM_MELTW_TYPE_BINARY_PACK = 8 +} libxsmm_meltw_binary_type; + +typedef enum libxsmm_meltw_ternary_flags { + LIBXSMM_MELTW_FLAG_TERNARY_NONE = 0, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_ROW_IN_0 = 1, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_ROW_IN_1 = 2, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_ROW_IN_2 = 4, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_COL_IN_0 = 8, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_COL_IN_1 = 16, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_COL_IN_2 = 32, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_SCALAR_IN_0 = 64, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_SCALAR_IN_1 = 128, + LIBXSMM_MELTW_FLAG_TERNARY_BCAST_SCALAR_IN_2 = 256, + LIBXSMM_MELTW_FLAG_TERNARY_REUSE_IN_2_AS_OUT = 512 +} libxsmm_meltw_ternary_flags; + +typedef enum libxsmm_meltw_ternary_type { + LIBXSMM_MELTW_TYPE_TERNARY_NONE = 0, + LIBXSMM_MELTW_TYPE_TERNARY_MULADD = 1, + LIBXSMM_MELTW_TYPE_TERNARY_MATMUL = 2, + LIBXSMM_MELTW_TYPE_TERNARY_BLEND = 3, + LIBXSMM_MELTW_TYPE_TERNARY_NMULADD = 4 +} libxsmm_meltw_ternary_type; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xmelt_flags { + libxsmm_meltw_null_flags elt_null; + libxsmm_meltw_opreduce_vecs_flags elt_opredvecs; + libxsmm_meltw_relu_flags elt_relu; + libxsmm_meltw_cvta_flags elt_cvta; + libxsmm_meltw_cvt_flags elt_cvt; + libxsmm_meltw_acvt_flags elt_acvt; + libxsmm_meltw_flags elt_meltwfused; +} libxsmm_xmelt_flags; + +/** Flag enumeration which can be binary ORed. */ +typedef enum libxsmm_gemm_flags { + LIBXSMM_GEMM_FLAG_NONE = 0, + /** Transpose matrix A. */ + LIBXSMM_GEMM_FLAG_TRANS_A = 1, + /** Transpose matrix B. */ + LIBXSMM_GEMM_FLAG_TRANS_B = 2, + /** Transpose matrix A and B. */ + LIBXSMM_GEMM_FLAG_TRANS_AB = LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B, +#if 0 + /** Alpha=0|1 */ + LIBXSMM_GEMM_FLAG_ALPHA_0 = 4, + /** Alpha=neg|pos */ + LIBXSMM_GEMM_FLAG_ALPHA_S = 8, +#endif + /** Beta=0|1 */ + LIBXSMM_GEMM_FLAG_BETA_0 = 16, +#if 0 + /** Beta=neg|pos */ + LIBXSMM_GEMM_FLAG_BETA_S = 32, +#endif + /** Generate aligned load instructions. */ + LIBXSMM_GEMM_FLAG_ALIGN_A = 64, + /** Aligned load/store instructions. */ + LIBXSMM_GEMM_FLAG_ALIGN_C = 128, + /** Batch-reduce Ai * Bi. */ + /** AMX hint to avoid tileconfig/release, it's negated bits, so that 0 is default "on" */ + LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG = 4, + LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG = 8, + LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS = 256, + /** Batch-reduce Ai * Bi. */ + LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET = 512, + /** Batch-reduce Ai * Bi. */ + LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE = 1024, + /** Aligned C matrix, but using NTS Hint when storing */ + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT = 2176, + /* in case of integer GEMM, if A is unsigned */ + LIBXSMM_GEMM_FLAG_A_UNSIGNED = 4096, + /* in case of integer GEMM, if B is unsigned */ + LIBXSMM_GEMM_FLAG_B_UNSIGNED = 8192, + /* in case of integer GEMM, if C is unsigned */ + LIBXSMM_GEMM_FLAG_C_UNSIGNED = 16384, + /* in case of integer GEMM, if A and B are unsigned */ + LIBXSMM_GEMM_FLAG_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + /* for low precision we also require up-front packed formats "VNNI" for best performance, this flag indicates A */ + LIBXSMM_GEMM_FLAG_VNNI_A = 32768, + /* for low precision we also require up-front packed formats "VNNI" for best performance, this flag indicates B */ + LIBXSMM_GEMM_FLAG_VNNI_B = 65536, + /* for low precision we also require post packed formats "VNNI" for best performance, this flag indicated C */ + LIBXSMM_GEMM_FLAG_VNNI_C = 131072, + /* combined types */ + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0 = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_A_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_B_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, + /** Marker flag; do not use. */ + LIBXSMM_GEMM_FLAG_INVALID = 262144 +} libxsmm_gemm_flags; + +/** Flag enumeration which can be binary ORed. */ +typedef enum libxsmm_gemm_handle_flags { + LIBXSMM_GEMM_HANDLE_FLAG_AUTO = 0, + LIBXSMM_GEMM_HANDLE_FLAG_COPY_A = 1, + LIBXSMM_GEMM_HANDLE_FLAG_COPY_B = 2, + LIBXSMM_GEMM_HANDLE_FLAG_COPY_C = 4 +} libxsmm_gemm_handle_flags; + +/** Auto-batch flags (can be ORed) applicable to mmbatch_begin/mmbatch_end. */ +typedef enum libxsmm_mmbatch_flags { + /** Handle recorded batch unsynchronized-parallel. */ + LIBXSMM_MMBATCH_FLAG_DEFAULT = LIBXSMM_GEMM_FLAG_INVALID * 0, + /** Synchronize among C matrices. */ + LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED = LIBXSMM_GEMM_FLAG_INVALID * 1, + /** Handle recorded batch sequentially. */ + LIBXSMM_MMBATCH_FLAG_SEQUENTIAL = LIBXSMM_GEMM_FLAG_INVALID * 2, + /** Only record a statistic of potential SMMs. */ + LIBXSMM_MMBATCH_FLAG_STATISTIC = LIBXSMM_GEMM_FLAG_INVALID * 4 +} libxsmm_mmbatch_flags; + +/** Enumeration of the available prefetch strategies. */ +typedef enum libxsmm_gemm_prefetch_type { + /** No prefetching and no prefetch fn. signature. */ + LIBXSMM_GEMM_PREFETCH_NONE = LIBXSMM_PREFETCH_NONE, + /** Only function prefetch signature. */ + LIBXSMM_GEMM_PREFETCH_SIGONLY = LIBXSMM_PREFETCH_SIGONLY, + /** Prefetch PA using accesses to A. */ + LIBXSMM_GEMM_PREFETCH_AL2 = 2, + /** Prefetch PA (aggressive). */ + LIBXSMM_GEMM_PREFETCH_BL2_VIA_C = 4, + /** Prefetch A ahead. */ + LIBXSMM_GEMM_PREFETCH_AL2_AHEAD = 8, + LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C = LIBXSMM_GEMM_PREFETCH_BL2_VIA_C | LIBXSMM_GEMM_PREFETCH_AL2, + LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD = LIBXSMM_GEMM_PREFETCH_BL2_VIA_C | LIBXSMM_GEMM_PREFETCH_AL2_AHEAD, + /** Backward compatibility: AL2CL2BL2_VIA_C is an alias for AL2BL2_VIA_C (Eigen library). */ + LIBXSMM_PREFETCH_AL2CL2BL2_VIA_C = LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C, + /** Current B into L1. */ + LIBXSMM_GEMM_PREFETCH_BL1 = 16, + LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB = 32 +} libxsmm_gemm_prefetch_type; + +/** Flag enumeration which can be binary ORed. */ +typedef enum libxsmm_matcopy_flags { + LIBXSMM_MATCOPY_FLAG_DEFAULT = 0, + /** If set, then use zero matrix as source */ + LIBXSMM_MATCOPY_FLAG_ZERO_SOURCE = 1 +} libxsmm_matcopy_flags; + +/** Determines the kernel kind. */ +typedef enum libxsmm_kernel_kind { + /** Matrix multiplication kernel */ + LIBXSMM_KERNEL_KIND_MATMUL = 0, + /** Mateltw kernel kind */ + LIBXSMM_KERNEL_KIND_MELTW = 1, + /** Mateqn kernel kind */ + LIBXSMM_KERNEL_KIND_MEQN = 2, + /** User-defined kernels */ + LIBXSMM_KERNEL_KIND_USER = 3, + /** Not a JIT kernel */ + LIBXSMM_KERNEL_UNREGISTERED = 4 +} libxsmm_kernel_kind; + +typedef enum libxsmm_dnn_tensor_format { + /* use LIBXSMM internal format, we need to copy data into that */ + LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM = 1, + /* use NHWC format internally, this allows no-copy operations */ + LIBXSMM_DNN_TENSOR_FORMAT_NHWC = 2, + /* use NCHW format internally, this will include shadow copies, not preferred */ + LIBXSMM_DNN_TENSOR_FORMAT_NCHW = 4, + /* use RSCK format internally, this allows no-copy operations */ + LIBXSMM_DNN_TENSOR_FORMAT_RSCK = 8, + /* use KCRS format internally, this will include shadow copies, not preferred */ + LIBXSMM_DNN_TENSOR_FORMAT_KCRS = 16, + LIBXSMM_DNN_TENSOR_FORMAT_CK = 32, + LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED = 64, + LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED = 128, + LIBXSMM_DNN_TENSOR_FORMAT_NC = 256 +} libxsmm_dnn_tensor_format; + +/** Denotes the element/pixel type of an image/channel. */ +typedef enum libxsmm_dnn_datatype { + LIBXSMM_DNN_DATATYPE_F64 = LIBXSMM_DATATYPE_F64, + LIBXSMM_DNN_DATATYPE_F32 = LIBXSMM_DATATYPE_F32, + LIBXSMM_DNN_DATATYPE_BF16 = LIBXSMM_DATATYPE_BF16, + LIBXSMM_DNN_DATATYPE_F16 = LIBXSMM_DATATYPE_F16, + LIBXSMM_DNN_DATATYPE_I32 = LIBXSMM_DATATYPE_I32, + LIBXSMM_DNN_DATATYPE_I16 = LIBXSMM_DATATYPE_I16, + LIBXSMM_DNN_DATATYPE_I8 = LIBXSMM_DATATYPE_I8 +} libxsmm_dnn_datatype; + +typedef enum libxsmm_dnn_conv_option { + /* we get default settings */ + LIBXSMM_DNN_CONV_OPTION_NONE = 0, + /* overwrite results buffer (set it to zero before running the operations) */ + LIBXSMM_DNN_CONV_OPTION_OVERWRITE = 1, + /* external filter transpose to bwd convolutions */ + LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE = 2, + /* compound types */ + LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE_OVERWRITE = LIBXSMM_DNN_CONV_OPTION_OVERWRITE | LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE +} libxsmm_dnn_conv_option; + +typedef enum libxsmm_dnn_fusedbatchnorm_fuse_order { + /* the fuse order is: 1. BN, 2. element-wise 3. RELU */ + LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU = 0 +} libxsmm_dnn_fusedbatchnorm_fuse_order; + +typedef enum libxsmm_dnn_fusedbatchnorm_fuse_op { + /* the fuse order is: 1. BN, 2. element-wise 3. RELU */ + LIBXSMM_DNN_FUSEDBN_OPS_BN = 1, + LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE = 2, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS = 4, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED = 8, + LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE = 16, + LIBXSMM_DNN_FUSEDBN_OPS_RELU = 32, + LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK = 64, + LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BN_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, + LIBXSMM_DNN_FUSEDBN_OPS_BN_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BN_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BN_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BN_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, + LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, + LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK +} libxsmm_dnn_fusedbatchnorm_fuse_op; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedbatchnorm_desc { + int partN; /* number of images in mini-batch, used for all elementwise computations */ + int fullN; /* number of images in mini-batch, used for statistics computations */ + int C; /* number of input feature maps */ + int H; /* height of input image */ + int W; /* width of input image */ + int u; /* vertical stride */ + int v; /* horizontal stride */ + int pad_h_in; /* height of physical zero-padding in input buffer */ + int pad_w_in; /* width of physical zero-padding in input buffer */ + int pad_h_out; /* height of physical zero-padding in output buffer */ + int pad_w_out; /* width of physical zero-padding in output buffer */ + int threads; /* number of threads used */ + libxsmm_dnn_datatype datatype_in; /* datatype used for all input related buffers */ + libxsmm_dnn_datatype datatype_out; /* datatype used for all output related buffers */ + libxsmm_dnn_datatype datatype_stats; /* datatype used for all stats related buffers */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ + libxsmm_dnn_fusedbatchnorm_fuse_order fuse_order; /* additional options */ + libxsmm_dnn_fusedbatchnorm_fuse_op fuse_ops; /* used ops into convolutions */ +} libxsmm_dnn_fusedbatchnorm_desc; + +typedef enum libxsmm_dnn_fusedgroupnorm_fuse_order { + /* the fuse order is: 1. BN, 2. element-wise 3. RELU */ + LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU = 0 +} libxsmm_dnn_fusedgroupnorm_fuse_order; + +typedef enum libxsmm_dnn_fusedgroupnorm_fuse_op { + /* the fuse order is: 1. GN, 2. element-wise 3. RELU */ + LIBXSMM_DNN_FUSEDGN_OPS_GN = 1, + LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE = 2, + LIBXSMM_DNN_FUSEDGN_OPS_RELU = 4, + LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK = 8, + LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU = LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU, + LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDGN_OPS_GN_ELTWISE = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE, + LIBXSMM_DNN_FUSEDGN_OPS_GN_RELU = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_RELU, + LIBXSMM_DNN_FUSEDGN_OPS_GN_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK, + LIBXSMM_DNN_FUSEDGN_OPS_GN_ELTWISE_RELU = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU, + LIBXSMM_DNN_FUSEDGN_OPS_GN_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK +} libxsmm_dnn_fusedgroupnorm_fuse_op; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedgroupnorm_desc { + int N; /* number of images in mini-batch */ + int G; /* groups of channels to norm */ + int C; /* number of input feature maps */ + int H; /* height of input image */ + int W; /* width of input image */ + int u; /* vertical stride */ + int v; /* horizontal stride */ + int pad_h_in; /* height of physical zero-padding in input buffer */ + int pad_w_in; /* width of physical zero-padding in input buffer */ + int pad_h_out; /* height of physical zero-padding in output buffer */ + int pad_w_out; /* width of physical zero-padding in output buffer */ + int threads; /* number of threads used */ + libxsmm_dnn_datatype datatype_in; /* datatype used for all input related buffers */ + libxsmm_dnn_datatype datatype_out; /* datatype used for all output related buffers */ + libxsmm_dnn_datatype datatype_stats; /* datatype used for all stats related buffers */ + libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ + libxsmm_dnn_fusedgroupnorm_fuse_order fuse_order; /* additional options */ + libxsmm_dnn_fusedgroupnorm_fuse_op fuse_ops; /* used ops into convolutions */ +} libxsmm_dnn_fusedgroupnorm_desc; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_matrix_arg { + void* primary; + void* secondary; + void* tertiary; +} libxsmm_matrix_arg; + +/** argument struct for matrix-eltwise: reduce */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_reduce_cols_idx_param { + unsigned long long n; + const void* ind_ptr; /* index array pointer */ + const void* inp_ptr; /* input pointer */ + void* out_ptr; /* output pointer */ +} libxsmm_meltw_reduce_cols_idx_param; + +/** argument struct for matrix-eltwise: opreduce vecs indexed */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_opreduce_vecs_idx_param { + unsigned long long n; + const void* indices; /* index array pointer */ + const void* in_matrix; /* input matrix pointer */ + const void* in_vec; /* input vector pointer */ + void* out_vec; /* output pointer */ + const void* scale_vals; /* scale values of indexed vectors after ops */ + const void* indices2; /* index array pointer */ + const void* in_matrix2; /* input matrix pointer */ + void* argop_off_vec_0; + void* argop_off_vec_1; +} libxsmm_meltw_opreduce_vecs_idx_param; + +/** argument struct for matrix-eltwise: unary */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_unary_param { + libxsmm_matrix_arg in; /* input */ + libxsmm_matrix_arg out; /* output */ +} libxsmm_meltw_unary_param; + +/** argument struct for matrix-eltwise: binary */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_binary_param { + libxsmm_matrix_arg in0; /* 1st input */ + libxsmm_matrix_arg in1; /* 2nd input */ + libxsmm_matrix_arg out; /* output */ +} libxsmm_meltw_binary_param; + +/** argument struct for matrix-eltwise: ternary */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_ternary_param { + libxsmm_matrix_arg in0; /* 1st input */ + libxsmm_matrix_arg in1; /* 2nd input */ + libxsmm_matrix_arg in2; /* 3rd input */ + libxsmm_matrix_arg out; /* output */ +} libxsmm_meltw_ternary_param; + +/** argument struct for matrix equation */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_matrix_eqn_param { + const libxsmm_matrix_arg* inputs; /* array of input args */ + libxsmm_matrix_arg output; /* output arg */ +} libxsmm_matrix_eqn_param; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_gemm_param { + const void* bias_ptr; /* optional, col-bias pointer */ + void* out_ptr; /* optional, pointer to output after eltwise (contains mask in case of ReLU); */ + /* Need for some activation functions, assumed to have the same shape as C matrix, */ + /* may not be set when OVERWRITE_C option is chosen */ + /* If OVERWRITE_C is false: out_ptr contains the post-act output, C has the pre-act output */ + /* If OVERWRITE_C is true: C contains post-act output, out_ptr contains the ReLU mask (only when act was ReLU) for other act unused */ + void* sparse_bitmap; + void* decompress_buffer; + void* relu_bitmask_bwd; +} libxsmm_meltw_gemm_param; + +/** Specialized function for matrix-eltw (weak-typed). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_reduce_cols_idx)(const libxsmm_meltw_reduce_cols_idx_param* in_struct); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_opreduce_vecs_idx)(const libxsmm_meltw_opreduce_vecs_idx_param* in_struct); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_unary)(const libxsmm_meltw_unary_param* in_struct); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_binary)(const libxsmm_meltw_binary_param* in_struct); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_ternary)(const libxsmm_meltw_ternary_param* in_struct); + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xmeltwfunction { + void (*xmeltw)(const void* in_struct); + libxsmm_meltwfunction_reduce_cols_idx meltw_reduce_cols_idx; + libxsmm_meltwfunction_opreduce_vecs_idx meltw_opreduce_vecs_idx; + libxsmm_meltwfunction_unary meltw_unary; + libxsmm_meltwfunction_binary meltw_binary; + libxsmm_meltwfunction_ternary meltw_ternary; +} libxsmm_xmeltwfunction; + +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (double-precision). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction)(const double* a, const double* b, double* c, ...); +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (single-precision). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction)(const float* a, const float* b, float* c, ...); +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (bf16, fp32-accumulate). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, ...); +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (bf16, fp32-accumulate). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, libxsmm_bfloat16* c, ...); +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (low-precision). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction)(const short* a, const short* b, int* c, ...); +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (int8, int32 accumulate). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction)(const char* a, const char* b, int* c, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction)(const unsigned char* a, const char* b, int* c, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction)(const char* a, const unsigned char* b, int* c, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction)(const unsigned char* a, const unsigned char* b, int* c, ...); +/** Specialized function with fused alpha and beta arguments, and optional prefetch locations (int8, int32 accumulate, int8 downconvert). */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction)(const char* a, const unsigned char* b, unsigned char* c, float* scf, ...); + +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction_reducebatch_addr)(const double** a, const double** b, double* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction_reducebatch_addr)(const float** a, const float** b, float* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_addr)(const libxsmm_bfloat16** a, const libxsmm_bfloat16** b, float* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_addr)(const libxsmm_bfloat16** a, const libxsmm_bfloat16** b, libxsmm_bfloat16* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction_reducebatch_addr)(const short** a, const short** b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction_reducebatch_addr)(const char** a, const char** b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction_reducebatch_addr)(const unsigned char** a, const char** b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction_reducebatch_addr)(const char** a, const unsigned char** b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction_reducebatch_addr)(const unsigned char** a, const unsigned char** b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction_reducebatch_addr)(const char** a, const unsigned char** b, unsigned char* c, const unsigned long long* count, float* scf, ...); + +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction_reducebatch_offs)(const double* a, const double* b, double* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction_reducebatch_offs)(const float* a, const float* b, float* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_offs)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_offs)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, libxsmm_bfloat16* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction_reducebatch_offs)(const short* a, const short* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction_reducebatch_offs)(const char* a, const char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction_reducebatch_offs)(const unsigned char* a, const char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction_reducebatch_offs)(const char* a, const unsigned char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction_reducebatch_offs)(const unsigned char* a, const unsigned char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction_reducebatch_offs)(const char* a, const unsigned char* b, unsigned char* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, float* scf, ...); + +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction_reducebatch_strd)(const double* a, const double* b, double* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction_reducebatch_strd)(const float* a, const float* b, float* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_strd)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_strd)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, libxsmm_bfloat16* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction_reducebatch_strd)(const short* a, const short* b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction_reducebatch_strd)(const char* a, const char* b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction_reducebatch_strd)(const unsigned char* a, const char* b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction_reducebatch_strd)(const char* a, const unsigned char* b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction_reducebatch_strd)(const unsigned char* a, const unsigned char* b, int* c, const unsigned long long* count, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction_reducebatch_strd)(const char* a, const unsigned char* b, unsigned char* c, const unsigned long long* count, float* scf, ...); + +/* GEMM fused with elwise */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_strd_meltwfused)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, libxsmm_bfloat16* c, const unsigned long long* count, const libxsmm_meltw_gemm_param* meltw_param, ...); +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_strd_meltwfused)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, const unsigned long long* count, const libxsmm_meltw_gemm_param* meltw_param, ...); + +/** Function type which is either libxsmm_smmfunction or libxsmm_dmmfunction (weak-typed). */ +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xmmfunction { + void (*xmm)(const void* a, const void* b, void* c, ...); + void (*xbm)(const void** a, const void** b, void* c, const unsigned long long* count, ...); + libxsmm_dmmfunction dmm; libxsmm_smmfunction smm; libxsmm_wimmfunction wimm; libxsmm_bsmmfunction bsmm; libxsmm_bmmfunction bmm; + libxsmm_ssbimmfunction ssbimm; libxsmm_usbimmfunction usbimm; libxsmm_subimmfunction subimm; libxsmm_uubimmfunction uubimm; libxsmm_sububmmfunction sububmm; + libxsmm_dmmfunction_reducebatch_addr dmra; libxsmm_smmfunction_reducebatch_addr smra; libxsmm_bsmmfunction_reducebatch_addr bsmra; libxsmm_bmmfunction_reducebatch_addr bmra; + libxsmm_wimmfunction_reducebatch_addr wimra; libxsmm_ssbimmfunction_reducebatch_addr ssbimra; libxsmm_usbimmfunction_reducebatch_addr usbimra; libxsmm_subimmfunction_reducebatch_addr subimra; libxsmm_uubimmfunction_reducebatch_addr uubimra; + libxsmm_sububmmfunction_reducebatch_addr sububmra; + libxsmm_dmmfunction_reducebatch_offs dmro; libxsmm_smmfunction_reducebatch_offs smro; libxsmm_bsmmfunction_reducebatch_offs bsmro; libxsmm_bmmfunction_reducebatch_offs bmro; + libxsmm_wimmfunction_reducebatch_offs wimro; libxsmm_ssbimmfunction_reducebatch_offs ssbimro; libxsmm_usbimmfunction_reducebatch_offs usbimro; libxsmm_subimmfunction_reducebatch_offs subimro; libxsmm_uubimmfunction_reducebatch_offs uubimro; + libxsmm_sububmmfunction_reducebatch_offs sububmro; + libxsmm_dmmfunction_reducebatch_strd dmrs; libxsmm_smmfunction_reducebatch_strd smrs; libxsmm_bsmmfunction_reducebatch_strd bsmrs; libxsmm_bmmfunction_reducebatch_strd bmrs; + libxsmm_wimmfunction_reducebatch_strd wimrs; libxsmm_ssbimmfunction_reducebatch_strd ssbimrs; libxsmm_usbimmfunction_reducebatch_strd usbimrs; libxsmm_subimmfunction_reducebatch_strd subimrs; libxsmm_uubimmfunction_reducebatch_strd uubimrs; + libxsmm_sububmmfunction_reducebatch_strd sububmrs; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bmrs_meltwfused; + libxsmm_bsmmfunction_reducebatch_strd_meltwfused bsmrs_meltwfused; +} libxsmm_xmmfunction; + +/* matrix equation function */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_matrix_eqn_function)(const libxsmm_matrix_eqn_param* in_struct); + +/** Structure to receive information about GEMM-kernels (libxsmm_get_mmkernel_info). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_mmkernel_info { + /** Input/output data-type */ + libxsmm_gemm_precision iprecision, oprecision; + /** Prefetch strategy. */ + libxsmm_gemm_prefetch_type prefetch; + /** Leading dimensions. */ + unsigned int lda, ldb, ldc; + /** Extents/shape. */ + unsigned int m, n, k; + /** Set of flags. */ + int flags; +} libxsmm_mmkernel_info; + +/** Structure to receive information about matrix-eltw kernels (libxsmm_get_meltwkernel_info). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltwkernel_info { + /** LDx, M, and N. */ + unsigned int ldi, ldo, m, n; + /** Size of data element. */ + unsigned int datatype; + /** Set of flags. */ + unsigned int flags; + /** Set of operation. */ + unsigned int operation; +} libxsmm_meltwkernel_info; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_kernel_info { + libxsmm_kernel_kind kind; + /** Number of FLoating Point OperationS (FLOPS). */ + unsigned int nflops; + /** Code size (Bytes). */ + size_t code_size; +} libxsmm_kernel_info; + +/** Structure to receive information about the code registry status (libxsmm_get_registry_info). */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_registry_info { + size_t capacity, size, nbytes, nstatic, ncache; +} libxsmm_registry_info; + +#endif /*LIBXSMM_TYPEDEFS_H*/ + diff --git a/third_party/libxsmm/include/libxsmm_version.h b/third_party/libxsmm/include/libxsmm_version.h new file mode 100644 index 0000000000000000000000000000000000000000..1c0bdd906a7b1a995609c048f39d8c66960fd22c --- /dev/null +++ b/third_party/libxsmm/include/libxsmm_version.h @@ -0,0 +1,13 @@ +#ifndef LIBXSMM_VERSION_H +#define LIBXSMM_VERSION_H + +#define LIBXSMM_CONFIG_VERSION "1.16.1-1534" +#define LIBXSMM_CONFIG_BRANCH "master" +#define LIBXSMM_CONFIG_VERSION_MAJOR 1 +#define LIBXSMM_CONFIG_VERSION_MINOR 16 +#define LIBXSMM_CONFIG_VERSION_UPDATE 1 +#define LIBXSMM_CONFIG_VERSION_PATCH 1534 +#define LIBXSMM_CONFIG_BUILD_DATE 20230510 + +#endif + diff --git a/third_party/libxsmm/obj/.make b/third_party/libxsmm/obj/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/obj/intel64/.make b/third_party/libxsmm/obj/intel64/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/obj/intel64/generator_aarch64_instructions.o b/third_party/libxsmm/obj/intel64/generator_aarch64_instructions.o new file mode 100644 index 0000000000000000000000000000000000000000..a37798f9410f85ea3e963ec8d1f06e3578e95d53 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_aarch64_instructions.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_common.o b/third_party/libxsmm/obj/intel64/generator_common.o new file mode 100644 index 0000000000000000000000000000000000000000..43f84a07bbce88548c76f0cc074974d6ca5d4c70 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_common.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_common_aarch64.o b/third_party/libxsmm/obj/intel64/generator_common_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..60e8e5ed6433244d8e7ef010db0762d355559720 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_common_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_common_x86.o b/third_party/libxsmm/obj/intel64/generator_common_x86.o new file mode 100644 index 0000000000000000000000000000000000000000..96e1f116d41a4cedbf12a5ba171b045ed798dbb4 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_common_x86.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm.o b/third_party/libxsmm/obj/intel64/generator_gemm.o new file mode 100644 index 0000000000000000000000000000000000000000..6c33deb33929b9f6143ce05802f30e740d441dfb Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_aarch64.o b/third_party/libxsmm/obj/intel64/generator_gemm_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..43e65ebce710531a0394c79aeec3b14328790209 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_amx.o b/third_party/libxsmm/obj/intel64/generator_gemm_amx.o new file mode 100644 index 0000000000000000000000000000000000000000..f125c0247ae6ba627f4bc88ce78640c7747bbe46 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_amx.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_amx_emu.o b/third_party/libxsmm/obj/intel64/generator_gemm_amx_emu.o new file mode 100644 index 0000000000000000000000000000000000000000..fc4deb0e61aa534fcfc79346c5015913f63d6e40 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_amx_emu.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_amx_microkernel.o b/third_party/libxsmm/obj/intel64/generator_gemm_amx_microkernel.o new file mode 100644 index 0000000000000000000000000000000000000000..fe853e4387c82efdeeec86f6936633d7c08864da Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_amx_microkernel.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_amx_microkernel_emu.o b/third_party/libxsmm/obj/intel64/generator_gemm_amx_microkernel_emu.o new file mode 100644 index 0000000000000000000000000000000000000000..2aa07abbb2e87447578dfe88b9371bcde1ac0528 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_amx_microkernel_emu.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_avx2_microkernel.o b/third_party/libxsmm/obj/intel64/generator_gemm_avx2_microkernel.o new file mode 100644 index 0000000000000000000000000000000000000000..2fa5ae1c29c57c57b2f143152dd862d01838570c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_avx2_microkernel.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_avx512_microkernel.o b/third_party/libxsmm/obj/intel64/generator_gemm_avx512_microkernel.o new file mode 100644 index 0000000000000000000000000000000000000000..60c31e1b0574a19094288590f5f43f52d6674b51 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_avx512_microkernel.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_avx_microkernel.o b/third_party/libxsmm/obj/intel64/generator_gemm_avx_microkernel.o new file mode 100644 index 0000000000000000000000000000000000000000..0bfaa883e1efb94a77c43a0473cafbafa90a2674 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_avx_microkernel.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_common.o b/third_party/libxsmm/obj/intel64/generator_gemm_common.o new file mode 100644 index 0000000000000000000000000000000000000000..a715ea1e033887573a0b51525163b81dc7425401 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_common.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_common_aarch64.o b/third_party/libxsmm/obj/intel64/generator_gemm_common_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..0244d5b7ac0e84be99374ce90376098f14368dcc Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_common_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_noarch.o b/third_party/libxsmm/obj/intel64/generator_gemm_noarch.o new file mode 100644 index 0000000000000000000000000000000000000000..a734003404017380a1a780bfa7266495e7793c44 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_noarch.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_sse_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_gemm_sse_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..bd7ed313d76dd4fa798360ec1436a5b80e669cfb Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_sse_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_gemm_sse_microkernel.o b/third_party/libxsmm/obj/intel64/generator_gemm_sse_microkernel.o new file mode 100644 index 0000000000000000000000000000000000000000..3160b1c50d31fe7e51270012cde7e1525457bd9e Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_gemm_sse_microkernel.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise.o b/third_party/libxsmm/obj/intel64/generator_mateltwise.o new file mode 100644 index 0000000000000000000000000000000000000000..6dcf2cddc55bf3d851b3c19a3e032281d337da3c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_misc_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_misc_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..ce787042c2d03cb6cca05da146fdc7e596d1f508 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_misc_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_reduce_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_reduce_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..d6df904d4046ada1d70babb8bda0896857663417 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_reduce_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_sse_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_sse_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..b1605f6c7176cf4d90f8fed41bbfe818ee5e015d Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_sse_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_avx.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_avx.o new file mode 100644 index 0000000000000000000000000000000000000000..03bd0967d9f493b6ec769f6bc54175548159bd55 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_avx.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_avx512.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..e5f21fe126688c948444b72cae9f189f0395ad42 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_common.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_common.o new file mode 100644 index 0000000000000000000000000000000000000000..8e5127af77bea45e86c5b94f05b16b22017dacfc Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_common.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_common_x86.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_common_x86.o new file mode 100644 index 0000000000000000000000000000000000000000..c168e70acb846298bd1c19da41ed54c5760f9e9b Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_common_x86.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_sse.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_sse.o new file mode 100644 index 0000000000000000000000000000000000000000..a1d22bcebce0e37da31df0559490a4d97ff44879 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_transform_sse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_mateltwise_unary_binary_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_mateltwise_unary_binary_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..58d1deaaa21683fa025ec3a50fdccde9c036c5a8 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_mateltwise_unary_binary_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_matequation.o b/third_party/libxsmm/obj/intel64/generator_matequation.o new file mode 100644 index 0000000000000000000000000000000000000000..1178ec41b530f0e3e43f343ff49b435c73dee5d4 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_matequation.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_matequation_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_matequation_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..4ca47e60a590bb4c2a6dc6e486f46f5aaa0f80dc Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_matequation_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_matequation_regblocks_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_matequation_regblocks_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..1b1cfcad1a02c7d94f3f160f0489223c1e1fcf5c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_matequation_regblocks_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_matequation_scratch_avx_avx512.o b/third_party/libxsmm/obj/intel64/generator_matequation_scratch_avx_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..1d376ef16ebce4565872d27c7750f98407fe731d Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_matequation_scratch_avx_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm.o b/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm.o new file mode 100644 index 0000000000000000000000000000000000000000..497b421f1d56c8a5272c8b238430ce37192faf4a Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm_aarch64.o b/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..080905596cabb6897c96b6ee5be7a64495c3749e Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..ad9130465696eccb9d2a5fe266c3369199617195 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_gemm_ac_rm_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm.o b/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm.o new file mode 100644 index 0000000000000000000000000000000000000000..6d29c66da8ca8ab489ebf59e4074bccbcef96475 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm_aarch64.o b/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..a1260bfcf46162524a496f8a1be8fa13c1addecd Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..52658d176820ac1a02fc9e3aa91c90d54875ab6a Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_gemm_bc_rm_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm.o new file mode 100644 index 0000000000000000000000000000000000000000..f38340f8a284a8911480a8cc367dfe1b4993ee61 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse.o new file mode 100644 index 0000000000000000000000000000000000000000..5fc2dd3aedbfbdb0830d44a7409b03441e4d60d2 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse_aarch64.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..93d56d3a1fd33272f38f93c78a20c04d8e369f59 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..2c6bd9c529f9d755a5e7f0eafa542662ce2d39a3 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_bsparse_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_csparse.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_csparse.o new file mode 100644 index 0000000000000000000000000000000000000000..566ce3edd38d07afa8e86c875e06961db014cfb8 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_csparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_csparse_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_csparse_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..b4613a810f481328f65adeab39b13603294efe01 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csc_csparse_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse.o new file mode 100644 index 0000000000000000000000000000000000000000..a492f54fc80aee4708614982db85b27c44fc1e46 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse_aarch64.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..a45aa25341da9cb4d84120386e5425b41746e9ed Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..d420cc09aed5429ad9ba57c05ff42d99208c44f2 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_asparse_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse.o new file mode 100644 index 0000000000000000000000000000000000000000..f4369b88fce5bb177198c9ba7c85c833e0322384 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse_aarch64.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse_aarch64.o new file mode 100644 index 0000000000000000000000000000000000000000..2c0bd993939f6b82b920cd853e5ab7c38193c9c7 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse_aarch64.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse_avx_avx2_avx512.o b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse_avx_avx2_avx512.o new file mode 100644 index 0000000000000000000000000000000000000000..1cf33353cadf8e7247cf400eba67b2bc78ad0e8d Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_packed_spgemm_csr_bsparse_avx_avx2_avx512.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm.o b/third_party/libxsmm/obj/intel64/generator_spgemm.o new file mode 100644 index 0000000000000000000000000000000000000000..8a588852d0bc64f82daed6b3b96da06b1cb80ad3 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm_csc_asparse.o b/third_party/libxsmm/obj/intel64/generator_spgemm_csc_asparse.o new file mode 100644 index 0000000000000000000000000000000000000000..3f8b3882cc5d49ca6dde0a43aad917842687701e Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm_csc_asparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm_csc_bsparse.o b/third_party/libxsmm/obj/intel64/generator_spgemm_csc_bsparse.o new file mode 100644 index 0000000000000000000000000000000000000000..e31c00cc7996792d5372722a2e8ec3b8a78693a1 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm_csc_bsparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm_csc_reader.o b/third_party/libxsmm/obj/intel64/generator_spgemm_csc_reader.o new file mode 100644 index 0000000000000000000000000000000000000000..03b908c4fe26ccbcf5d75350946af7766675cc1c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm_csc_reader.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm_csr_asparse.o b/third_party/libxsmm/obj/intel64/generator_spgemm_csr_asparse.o new file mode 100644 index 0000000000000000000000000000000000000000..5dbe6b33bcee7058b5bc45b0c19eac0012812b9a Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm_csr_asparse.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm_csr_asparse_reg.o b/third_party/libxsmm/obj/intel64/generator_spgemm_csr_asparse_reg.o new file mode 100644 index 0000000000000000000000000000000000000000..31a58a5873312addd7080de4b91eaf27bf485c09 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm_csr_asparse_reg.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_spgemm_csr_reader.o b/third_party/libxsmm/obj/intel64/generator_spgemm_csr_reader.o new file mode 100644 index 0000000000000000000000000000000000000000..ca0b2b6058ade1e02fde9718820ccc6cba453949 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_spgemm_csr_reader.o differ diff --git a/third_party/libxsmm/obj/intel64/generator_x86_instructions.o b/third_party/libxsmm/obj/intel64/generator_x86_instructions.o new file mode 100644 index 0000000000000000000000000000000000000000..81ff0a85377766ef8edd4fb80aac4a85145cf1b4 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/generator_x86_instructions.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm-mod.o b/third_party/libxsmm/obj/intel64/libxsmm-mod.o new file mode 100644 index 0000000000000000000000000000000000000000..edd73c49b852688d23c54bce742a27b50be583cc Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm-mod.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_cpuid_arm.o b/third_party/libxsmm/obj/intel64/libxsmm_cpuid_arm.o new file mode 100644 index 0000000000000000000000000000000000000000..3e9ecc90272cff0b421bb18945b6282944b00b50 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_cpuid_arm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_cpuid_x86.o b/third_party/libxsmm/obj/intel64/libxsmm_cpuid_x86.o new file mode 100644 index 0000000000000000000000000000000000000000..ada0c9c09ea86d9a7468f6edb4fefa4cd3175424 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_cpuid_x86.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn.o new file mode 100644 index 0000000000000000000000000000000000000000..9b980615c4d28a306b578f2976aec230773d2ee1 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution.o new file mode 100644 index 0000000000000000000000000000000000000000..64b703c901876c27560184cd9e74b5957555c8a2 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_backward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_backward.o new file mode 100644 index 0000000000000000000000000000000000000000..243e077f4421d6e9af5a223f20b4e3fc61769908 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_backward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..24112eb6b89d03d8fad741ad257091ca182aa813 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_weight_update.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_weight_update.o new file mode 100644 index 0000000000000000000000000000000000000000..bc90715ace3671d8d10377a797a1e28d57397bd6 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_convolution_weight_update.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_elementwise.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_elementwise.o new file mode 100644 index 0000000000000000000000000000000000000000..403a4c38fe0f1e2d42231a7c91ffd1f69821320c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_elementwise.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected.o new file mode 100644 index 0000000000000000000000000000000000000000..f40c25aa8ec3845330007da4fe1b9fa8abdc5f1e Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected_backward_weight_update.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected_backward_weight_update.o new file mode 100644 index 0000000000000000000000000000000000000000..a489205e7cfeebde53e174f8a37a95cfd434f222 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected_backward_weight_update.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..eeebcf14d8b78cc28f051d8ee3460f2dbaf1cc2d Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fullyconnected_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm.o new file mode 100644 index 0000000000000000000000000000000000000000..8e11b865d6d604293136aec63c51eab6ae65b8e3 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm_backward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm_backward.o new file mode 100644 index 0000000000000000000000000000000000000000..682c31aaece1267c0cb35fa8c83ec78c14a376af Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm_backward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..e24bbbdbce2b4261f3b6b3193596d7c20258ed2e Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedbatchnorm_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm.o new file mode 100644 index 0000000000000000000000000000000000000000..b9da9931455be7dab5fa1f089cfdc6ed9aad5711 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm_backward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm_backward.o new file mode 100644 index 0000000000000000000000000000000000000000..280d07a6fb31649b8385193b89dc32bb36a32fd6 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm_backward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..019ebba85861d5763dea9b9d67e8ab11cbd6e9ba Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_fusedgroupnorm_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_optimizer.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_optimizer.o new file mode 100644 index 0000000000000000000000000000000000000000..a8d509b90c0896725112f2ca3c6ce12b8c027b32 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_optimizer.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_optimizer_sgd.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_optimizer_sgd.o new file mode 100644 index 0000000000000000000000000000000000000000..2b00d754f11119f7fc46554be7567b22892c6925 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_optimizer_sgd.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling.o new file mode 100644 index 0000000000000000000000000000000000000000..8d0bd632a2df09fe8566f480e4c3de52e6754dae Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling_backward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling_backward.o new file mode 100644 index 0000000000000000000000000000000000000000..87b419ef3ae131ed8bcce8fc9470cdfd6a7a2766 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling_backward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..14039d2cf2f6894514aeb6d2b44f2b00ec5bf81b Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_pooling_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell.o new file mode 100644 index 0000000000000000000000000000000000000000..acae445562ddda108a67b0d6334386ab7b5fa90c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell_backward_weight_update.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell_backward_weight_update.o new file mode 100644 index 0000000000000000000000000000000000000000..215bc074e45e63d4538a6ea7764a2e110833ef98 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell_backward_weight_update.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..6710f0aadbb9e922cb1f2b3f26c4041e098776ec Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_rnncell_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss.o new file mode 100644 index 0000000000000000000000000000000000000000..1c5c6590013c59c56df6bdb4a08ed07793d5b02a Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss_backward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss_backward.o new file mode 100644 index 0000000000000000000000000000000000000000..bd98388f1cf47cecd879be98681a6e64e1fb8074 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss_backward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss_forward.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss_forward.o new file mode 100644 index 0000000000000000000000000000000000000000..a6cee187667948c38b36050239e77f1f2b8e7f0c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_softmaxloss_forward.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_dnn_tensor.o b/third_party/libxsmm/obj/intel64/libxsmm_dnn_tensor.o new file mode 100644 index 0000000000000000000000000000000000000000..4c48e95a3fc50bb45ac88c3ea58d380ec98fcdaf Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_dnn_tensor.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_ext.o b/third_party/libxsmm/obj/intel64/libxsmm_ext.o new file mode 100644 index 0000000000000000000000000000000000000000..ab0db087577b6acef14f8cce5ea352dc8eb4e75b Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_ext.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_ext_gemm.o b/third_party/libxsmm/obj/intel64/libxsmm_ext_gemm.o new file mode 100644 index 0000000000000000000000000000000000000000..b7b11eabcf1cfe491dfb7fdcb006a4ca7fe80f82 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_ext_gemm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_ext_xcopy.o b/third_party/libxsmm/obj/intel64/libxsmm_ext_xcopy.o new file mode 100644 index 0000000000000000000000000000000000000000..e3e0f57fd36da190964588f9be5e47e0f84e0af7 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_ext_xcopy.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_fsspmdm.o b/third_party/libxsmm/obj/intel64/libxsmm_fsspmdm.o new file mode 100644 index 0000000000000000000000000000000000000000..9fc1de47917773c3048951ed3f8fd19f7d8024d4 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_fsspmdm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_gemm.o b/third_party/libxsmm/obj/intel64/libxsmm_gemm.o new file mode 100644 index 0000000000000000000000000000000000000000..0c65c7f8a8d3c745c3cb2e9fe8c63935144e4c83 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_gemm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_generator.o b/third_party/libxsmm/obj/intel64/libxsmm_generator.o new file mode 100644 index 0000000000000000000000000000000000000000..72b9047dae909d5f078366b7ca915581fe94fefd Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_generator.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_generator_gemm_driver.o b/third_party/libxsmm/obj/intel64/libxsmm_generator_gemm_driver.o new file mode 100644 index 0000000000000000000000000000000000000000..938c8d1bae01f9cf708897e97ea51a792f7aede6 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_generator_gemm_driver.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_hash.o b/third_party/libxsmm/obj/intel64/libxsmm_hash.o new file mode 100644 index 0000000000000000000000000000000000000000..b036d8fcf1f42c12ca3d8807092b27e8a9445b70 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_hash.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_main.o b/third_party/libxsmm/obj/intel64/libxsmm_main.o new file mode 100644 index 0000000000000000000000000000000000000000..c238590b5ec3f17560fa5ae3ac1a579e5e1bf51c Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_main.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_malloc.o b/third_party/libxsmm/obj/intel64/libxsmm_malloc.o new file mode 100644 index 0000000000000000000000000000000000000000..b6e92b973e92106ef7da19ac4a297ba9660bbaf1 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_malloc.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_math.o b/third_party/libxsmm/obj/intel64/libxsmm_math.o new file mode 100644 index 0000000000000000000000000000000000000000..2e83312233caaa6685c10fa077873df02acf38b5 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_math.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_matrixeqn.o b/third_party/libxsmm/obj/intel64/libxsmm_matrixeqn.o new file mode 100644 index 0000000000000000000000000000000000000000..ef72d412a8e4deffca4f42ebaa1f4255ed01a141 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_matrixeqn.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_memory.o b/third_party/libxsmm/obj/intel64/libxsmm_memory.o new file mode 100644 index 0000000000000000000000000000000000000000..82f628d413254583cfa2d8177c9b345ef1b273e0 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_memory.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_mhd.o b/third_party/libxsmm/obj/intel64/libxsmm_mhd.o new file mode 100644 index 0000000000000000000000000000000000000000..f7ca4cd06e101e1cd1d4e3d98a8180f1de8647a0 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_mhd.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_noblas.o b/third_party/libxsmm/obj/intel64/libxsmm_noblas.o new file mode 100644 index 0000000000000000000000000000000000000000..ee54b138a9286dba044e8f882453735355730add Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_noblas.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_perf.o b/third_party/libxsmm/obj/intel64/libxsmm_perf.o new file mode 100644 index 0000000000000000000000000000000000000000..35c387a89ff528e2236590a45ea10c02378ccaea Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_perf.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_python.o b/third_party/libxsmm/obj/intel64/libxsmm_python.o new file mode 100644 index 0000000000000000000000000000000000000000..13e5bffd11ad168768231455759a944eaf2801db Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_python.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_rng.o b/third_party/libxsmm/obj/intel64/libxsmm_rng.o new file mode 100644 index 0000000000000000000000000000000000000000..c6fe7cea28970168df4aabbd299a73469c85323e Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_rng.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_spmdm.o b/third_party/libxsmm/obj/intel64/libxsmm_spmdm.o new file mode 100644 index 0000000000000000000000000000000000000000..34acd1e73424f0ec108c588cc1fa62d696b6e38b Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_spmdm.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_sync.o b/third_party/libxsmm/obj/intel64/libxsmm_sync.o new file mode 100644 index 0000000000000000000000000000000000000000..1ea4199fe0d0bd257746eeffc1582bd4f6cdf1aa Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_sync.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_timer.o b/third_party/libxsmm/obj/intel64/libxsmm_timer.o new file mode 100644 index 0000000000000000000000000000000000000000..530d82eebaa3888228b02565e750deccdc67a6b8 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_timer.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_trace.o b/third_party/libxsmm/obj/intel64/libxsmm_trace.o new file mode 100644 index 0000000000000000000000000000000000000000..b8a118abce233b38974f8b75c6f642711a518d97 Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_trace.o differ diff --git a/third_party/libxsmm/obj/intel64/libxsmm_xcopy.o b/third_party/libxsmm/obj/intel64/libxsmm_xcopy.o new file mode 100644 index 0000000000000000000000000000000000000000..f4e1f0161e5aad3a5e497f233898fed82b5bd0eb Binary files /dev/null and b/third_party/libxsmm/obj/intel64/libxsmm_xcopy.o differ diff --git a/third_party/libxsmm/obj/libxsmm_dispatch.h b/third_party/libxsmm/obj/libxsmm_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0f2804e30f1f9f7ec6832e0d2d5a9f50fa7c4225 --- /dev/null +++ b/third_party/libxsmm/obj/libxsmm_dispatch.h @@ -0,0 +1,7 @@ +#if !defined(_WIN32) +{ static const char *const build_state = +# include "../.state" + ; + internal_build_state = build_state; +} +#endif diff --git a/third_party/libxsmm/samples/cp2k/.make b/third_party/libxsmm/samples/cp2k/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/samples/deeplearning/tvm_cnnlayer/libxsmm_wrapper/Makefile b/third_party/libxsmm/samples/deeplearning/tvm_cnnlayer/libxsmm_wrapper/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..17404d66c6b236e2c0f43542874b39f21683e552 --- /dev/null +++ b/third_party/libxsmm/samples/deeplearning/tvm_cnnlayer/libxsmm_wrapper/Makefile @@ -0,0 +1,19 @@ +CC= icpc +CFLAGS= -O3 -fPIC -std=c++11 -fopenmp +LDFLAGS= -shared +SOURCES = batch_reduce_plus_init.cc +LIBXSMMDIR=./../../../../ + +INC=-I$(LIBXSMMDIR)/include +LIBS = $(LIBXSMMDIR)/lib/libxsmm.a $(LIBXSMMDIR)/lib/libxsmmext.a \ + $(LIBXSMMDIR)/lib/libxsmmnoblas.a $(LIBXSMMDIR)/lib/libxsmmgen.a \ + $(LIBXSMMDIR)/lib/libxsmmf.a + +TARGET= libxsmm_wrapper.so + +all: + $(CC) $(INC) $(CFLAGS) -fPIC $(SOURCES) $(LIBS) -o $(TARGET) $(LDFLAGS) + +clean: + rm -f $(TARGET) + diff --git a/third_party/libxsmm/samples/deeplearning/tvm_cnnlayer/libxsmm_wrapper/batch_reduce_plus_init.cc b/third_party/libxsmm/samples/deeplearning/tvm_cnnlayer/libxsmm_wrapper/batch_reduce_plus_init.cc new file mode 100644 index 0000000000000000000000000000000000000000..c51a0ebde4634357d380b7b61ddb7df7b14a8d7d --- /dev/null +++ b/third_party/libxsmm/samples/deeplearning/tvm_cnnlayer/libxsmm_wrapper/batch_reduce_plus_init.cc @@ -0,0 +1,89 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Anand Venkat (Intel Corp.) +******************************************************************************/ + +#include +#include + +extern "C" int batch_reduce_kernel_update(const float *weight, const float *input, float *output, int blocks, int ofmblock, int ifmblock, int ofw, int stride_w, int r, int s, int ifh, int ifw){ + int ld_b = stride_w*ifmblock; + libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr(ofmblock,ofw, ifmblock,NULL,&ld_b,NULL,NULL,NULL, NULL, NULL); + const unsigned long long cblocks = blocks; + const float * A[cblocks]; + const float * B[cblocks]; + int weight_stride = ofmblock*ifmblock*r*s; + int input_stride = ifw*ifh*ifmblock; + if(r == 1 && s == 1){ + for (int icb = 0; icb < cblocks; icb ++) { + A[icb] = &weight[icb*weight_stride]; + B[icb] = &input[icb*input_stride]; + } + }else{/*Eg.if( r == 3 && s == 3){*/ + for( int k = 0 ; k < blocks/(r*s); k++){ + for(int i=0; i < r; i++){ + for(int j =0; j < s; j++){ + A[k*r*s + i*s + j] = &weight[k*r*s*ofmblock*ifmblock + (i*s + j)*ofmblock*ifmblock]; + B[k*r*s + i*s + j] = &input[k*ifw*ifh*ifmblock + i*ifw*ifmblock + j*ifmblock]; + } + } + } + } + + /* Reduce batch gemm call */ + batchreduce_kernela(A, B, output, &cblocks); + + return 0; +} + +extern "C" int batch_reduce_kernel_init_update(const float *weight, const float *input, float *output, int blocks, int ofmblock, int ifmblock,int r, int s, int ifh, int ifw,int ofw, int stride_w ){ + float beta = 0.0; + int lda = ofmblock; + int ldx = ofmblock; + int ld_b = stride_w*ifmblock; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ); + libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr(ofmblock,ofw, ifmblock,&lda,&ld_b,&ldx,NULL,&beta, &l_flags, NULL); + + const unsigned long long cblocks = blocks; + const float * A[cblocks]; + const float * B[cblocks]; + int weight_stride = ofmblock*ifmblock*r*s; + int input_stride = ifw*ifh*ifmblock; + if(r == 1 && s == 1){ + for (int icb = 0; icb < cblocks; icb ++) { + A[icb] = &weight[icb*weight_stride]; + B[icb] = &input[icb*input_stride]; + } + }else{ /*if( r == 3 && s == 3){*/ + for( int k = 0 ; k < blocks/(r*s); k++) + for(int i=0; i < r; i++) + for(int j =0; j < s; j++){ + A[k*r*s + i*s + j] = &weight[k*r*s*ofmblock*ifmblock + (i*s + j)*ofmblock*ifmblock]; + B[k*r*s + i*s + j] = &input[k*ifw*ifh*ifmblock + i*ifw*ifmblock + j*ifmblock]; + } + + } + /* Reduce batch gemm call */ + batchreduce_kernela(A, B, output, &cblocks); + + + return 0; +} + +extern "C" int batch_reduce_kernel_init(float *output, int ofmblock, int ofw){ + int num_elements = ofw*ofmblock; + + LIBXSMM_PRAGMA_SIMD + for(int i=0; i < num_elements; i++) + output[i] = 0.0; + + return 0; +} + + diff --git a/third_party/libxsmm/samples/nek/.make b/third_party/libxsmm/samples/nek/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/samples/smm/.make b/third_party/libxsmm/samples/smm/.make new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd b/third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd new file mode 100644 index 0000000000000000000000000000000000000000..9eaad14031b1c51fe700d0e5446f824d60538a48 Binary files /dev/null and b/third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd differ diff --git a/third_party/libxsmm/scripts/libxsmm_config.py b/third_party/libxsmm/scripts/libxsmm_config.py new file mode 100755 index 0000000000000000000000000000000000000000..fd49ca80128b74f7198a6215540a2b4bcfccd564 --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_config.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) Intel Corporation - All rights reserved. # +# This file is part of the LIBXSMM library. # +# # +# For information on the license, see the LICENSE file. # +# Further information: https://github.com/hfp/libxsmm/ # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################### +# Hans Pabst (Intel Corp.) +############################################################################### +from string import Template +from datetime import date +import libxsmm_utilities +import fnmatch +import sys + + +if __name__ == "__main__": + argc = len(sys.argv) + if 1 < argc: + # required argument(s) + filename = sys.argv[1] + + # default configuration if no arguments are given + ilp64 = offload = precision = flags = threshold = 0 + sync = jit = 1 + alpha = beta = 1 + cacheline = 64 + prefetch = -1 + wrap = 1 + malloc = 0 + mnklist = list() + + # optional argument(s) + if 2 < argc: + ilp64 = int(sys.argv[2]) + if 3 < argc: + offload = int(sys.argv[3]) + if 4 < argc: + cacheline = libxsmm_utilities.sanitize_alignment(int(sys.argv[4])) + if 5 < argc: + precision = int(sys.argv[5]) + if 6 < argc: + prefetch = int(sys.argv[6]) + if 7 < argc: + threshold = int(sys.argv[7]) + if 8 < argc: + sync = int(sys.argv[8]) + if 9 < argc: + jit = int(sys.argv[9]) + if 10 < argc: + flags = int(sys.argv[10]) + if 11 < argc: + alpha = int(sys.argv[11]) + if 12 < argc: + beta = int(sys.argv[12]) + if 13 < argc: + wrap = int(sys.argv[13]) + if 14 < argc: + malloc = int(sys.argv[14]) + if 15 < argc: + mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[15:], 0)) + + version, branch, realversion = libxsmm_utilities.version_branch() + major, minor, update, patch = libxsmm_utilities.version_numbers( + version + ) + + if 0 == threshold: + threshold = 64 * 64 * 64 + maxmnk = libxsmm_utilities.max_mnk(mnklist, threshold) + maxdim = int(maxmnk ** (1.0 / 3.0) + 0.5) + avgdim = int(0.5 * maxdim + 0.5) + + avgm = libxsmm_utilities.median( + list(map(lambda mnk: mnk[0], mnklist)), avgdim, False + ) + avgn = libxsmm_utilities.median( + list(map(lambda mnk: mnk[1], mnklist)), avgdim, False + ) + avgk = libxsmm_utilities.median( + list(map(lambda mnk: mnk[2], mnklist)), avgdim, False + ) + + maxm = libxsmm_utilities.max_mnk(mnklist, avgdim, 0) + maxn = libxsmm_utilities.max_mnk(mnklist, avgdim, 1) + maxk = libxsmm_utilities.max_mnk(mnklist, avgdim, 2) + + substitute = { + "VERSION": realversion, + "BRANCH": branch, + "MAJOR": major, + "MINOR": minor, + "UPDATE": update, + "PATCH": patch, + "DATE": date.today().strftime("%Y%m%d"), + "CACHELINE": cacheline, + "PREFETCH": [-1, prefetch][0 <= prefetch], + "MAX_MNK": maxmnk, + "MAX_DIM": maxdim, + "AVG_DIM": int((maxdim + 1) / 2), + "MAX_M": [maxdim, maxm][avgm < maxm], + "MAX_N": [maxdim, maxn][avgn < maxn], + "MAX_K": [maxdim, maxk][avgk < maxk], + "FLAGS": flags, + "ILP64": [0, 1][0 != ilp64], + "ALPHA": alpha, + "BETA": beta, + "WRAP": wrap, + "MALLOC": malloc, + "SYNC": [0, 1][0 != sync], + "JIT": [0, 1][0 != jit], + "LIBXSMM_OFFLOAD_BUILD": ["", "\n#define LIBXSMM_OFFLOAD_BUILD"][ + 0 != offload + ], + "MNK_PREPROCESSOR_LIST": "", + } + + template = Template(open(filename, "r").read()) + if fnmatch.fnmatch(filename, "*.h*"): + if mnklist: + first = mnklist[0] + for mnk in mnklist: + mnkstr = "_".join(map(str, mnk)) + if mnk != first: + substitute["MNK_PREPROCESSOR_LIST"] += "\n" + if 2 != precision: + substitute["MNK_PREPROCESSOR_LIST"] += ( + "#define LIBXSMM_SMM_" + mnkstr + ) + if mnk != first or 0 == precision: + substitute["MNK_PREPROCESSOR_LIST"] += "\n" + if 1 != precision: + substitute["MNK_PREPROCESSOR_LIST"] += ( + "#define LIBXSMM_DMM_" + mnkstr + ) + + print(template.substitute(substitute)) + else: + substitute["BLASINT_KIND"] = ["C_INT", "C_LONG_LONG"][0 != ilp64] + print(template.safe_substitute(substitute)) + else: + sys.tracebacklimit = 0 + raise ValueError(sys.argv[0] + ": wrong number of arguments!") diff --git a/third_party/libxsmm/scripts/libxsmm_dispatch.py b/third_party/libxsmm/scripts/libxsmm_dispatch.py new file mode 100755 index 0000000000000000000000000000000000000000..7fab68e0de506c8b56f5ae5e685595a3fe684368 --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_dispatch.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) Intel Corporation - All rights reserved. # +# This file is part of the LIBXSMM library. # +# # +# For information on the license, see the LICENSE file. # +# Further information: https://github.com/hfp/libxsmm/ # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################### +# Hans Pabst (Intel Corp.) +############################################################################### +import libxsmm_utilities +import sys +import os + + +if __name__ == "__main__": + argc = len(sys.argv) + if 1 < argc: + arg1_filename = [sys.argv[1], ""]["0" == sys.argv[1]] + arg1_isfile = os.path.isfile(arg1_filename) + base = 1 + if arg1_isfile: + print("#if !defined(_WIN32)") + print("{ static const char *const build_state =") + print('# include "../' + os.path.basename(arg1_filename) + '"') + print(" ;") + print(" internal_build_state = build_state;") + print("}") + print("#endif") + base = 2 + if (base + 2) < argc: + precision = int(sys.argv[base + 0]) + threshold = int(sys.argv[base + 1]) + mnklist = libxsmm_utilities.load_mnklist(sys.argv[base + 2:], 0) + print( + "/* omit registering code if JIT is enabled" + " and if an ISA extension is found" + ) + print( + " * which is beyond the static code" + " path used to compile the library" + ) + print(" */") + print("#if (0 != LIBXSMM_JIT) && !defined(__MIC__)") + print( + "if (LIBXSMM_X86_GENERIC > libxsmm_target_archid " + "/* JIT code gen. is not available */" + ) + print( + " /* conditions allows to avoid JIT " + "(if static code is good enough) */" + ) + print( + " || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)" + ) + print( + " || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&" + ) + print( + " libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) ==" + ) + print( + " libxsmm_cpuid_vlen32(libxsmm_target_archid)))" + ) + print("#endif") + print("{") + print(" libxsmm_xmmfunction func;") + for mnk in mnklist: + mstr, nstr, kstr, mnkstr = ( + str(mnk[0]), + str(mnk[1]), + str(mnk[2]), + "_".join(map(str, mnk)), + ) + mnksig = mstr + ", " + nstr + ", " + kstr + # prefer registering double-precision kernels + # when approaching an exhausted registry + if 1 != precision: # only double-precision + print( + " func.dmm = (libxsmm_dmmfunction)libxsmm_dmm_" + + mnkstr + + ";" + ) + print( + " internal_register_static_code(" + + "LIBXSMM_GEMM_PRECISION_F64, " + + mnksig + + ", func, new_registry);" + ) + for mnk in mnklist: + mstr, nstr, kstr, mnkstr = ( + str(mnk[0]), + str(mnk[1]), + str(mnk[2]), + "_".join(map(str, mnk)), + ) + mnksig = mstr + ", " + nstr + ", " + kstr + # prefer registering double-precision kernels + # when approaching an exhausted registry + if 2 != precision: # only single-precision + print( + " func.smm = (libxsmm_smmfunction)libxsmm_smm_" + + mnkstr + + ";" + ) + print( + " internal_register_static_code(" + + "LIBXSMM_GEMM_PRECISION_F32, " + + mnksig + + ", func, new_registry);" + ) + print("}") + else: + sys.tracebacklimit = 0 + raise ValueError(sys.argv[0] + ": wrong number of arguments!") diff --git a/third_party/libxsmm/scripts/libxsmm_interface.py b/third_party/libxsmm/scripts/libxsmm_interface.py new file mode 100755 index 0000000000000000000000000000000000000000..9c013d8ce3e951bde7c1ba407b85185757c9cec3 --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_interface.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) Intel Corporation - All rights reserved. # +# This file is part of the LIBXSMM library. # +# # +# For information on the license, see the LICENSE file. # +# Further information: https://github.com/hfp/libxsmm/ # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################### +# Hans Pabst (Intel Corp.) +############################################################################### +from string import Template +import libxsmm_utilities +import fnmatch +import sys + + +if __name__ == "__main__": + argc = len(sys.argv) + if 1 < argc: + # required argument(s) + filename = sys.argv[1] + + # default configuration if no arguments are given + precision = 0 # all + ifversion = 1 # interface + prefetch = -1 # auto + mnklist = list() + + # optional argument(s) + if 2 < argc: + ivalue = int(sys.argv[2]) + ifversion = (ivalue >> 2) + precision = (ivalue & 3) + if 3 < argc: + prefetch = int(sys.argv[3]) + if 4 < argc: + mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[4:], 0)) + + template = Template(open(filename, "r").read()) + if fnmatch.fnmatch(filename, "*.h*"): + optional = [", ...", ""][0 <= prefetch] + substitute = {"MNK_INTERFACE_LIST": ""} + for mnk in mnklist: + mnkstr = "_".join(map(str, mnk)) + if 2 != precision: + pfsig = [ + optional + ");", + ",\n " + "const float* pa, " + "const float* pb, " + "const float* pc);" + ][0 < prefetch] + substitute["MNK_INTERFACE_LIST"] += ( + "\nLIBXSMM_API void libxsmm_smm_" + + mnkstr + + "(const float* a, const float* b, float* c" + + pfsig + ) + if 1 != precision: + pfsig = [ + optional + ");", + ",\n " + "const double* pa, " + "const double* pb, " + "const double* pc);" + ][0 < prefetch] + substitute["MNK_INTERFACE_LIST"] += ( + "\nLIBXSMM_API void libxsmm_dmm_" + + mnkstr + + "(const double* a, const double* b, double* c" + + pfsig + ) + if 0 == precision: + substitute["MNK_INTERFACE_LIST"] += "\n" + if mnklist and 0 != precision: + substitute["MNK_INTERFACE_LIST"] += "\n" + print(template.substitute(substitute)) + else: # Fortran interface + if 1 > ifversion and 0 != ifversion: + raise ValueError("Fortran interface level is inconsistent!") + # Fortran's OPTIONAL allows to always generate an interface + # with prefetch signature (more flexible usage) + if 0 == prefetch: + prefetch = -1 + version, branch, realversion = libxsmm_utilities.version_branch(16) + major, minor, update, patch = libxsmm_utilities.version_numbers( + version + ) + substitute = { + "VERSION": realversion, + "BRANCH": branch, + "MAJOR": major, + "MINOR": minor, + "UPDATE": update, + "PATCH": patch, + "MNK_INTERFACE_LIST": "", + "CONTIGUOUS": ["", ", CONTIGUOUS"][1 < ifversion] + } + if mnklist: + substitute["MNK_INTERFACE_LIST"] += "\n" + for mnk in mnklist: + mnkstr = "_".join(map(str, mnk)) + if 0 == precision: + substitute["MNK_INTERFACE_LIST"] += ( + "\n " + "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_" + + mnkstr + + ", libxsmm_dmm_" + + mnkstr + ) + elif 2 != precision: + substitute["MNK_INTERFACE_LIST"] += ( + "\n " + "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_" + + mnkstr + ) + elif 1 != precision: + substitute["MNK_INTERFACE_LIST"] += ( + "\n " + "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_dmm_" + + mnkstr + ) + substitute["MNK_INTERFACE_LIST"] += "\n INTERFACE" + optional = [", OPTIONAL", ""][0 < prefetch] + bindc = ["", "BIND(C)"][0 < prefetch] + for mnk in mnklist: + mnkstr = "_".join(map(str, mnk)) + if 2 != precision: + pfsiga = [ + ") BIND(C)\n", + "," + + "&".rjust(26 - len(mnkstr)) + + "\n & pa, pb, pc) " + + bindc + + "\n" + ][0 != prefetch] + pfsigb = [ + "", + " REAL(C_FLOAT), " + "INTENT(IN)" + optional + " :: " + "pa(*), " + "pb(*), " + "pc(*)\n" + ][0 != prefetch] + substitute["MNK_INTERFACE_LIST"] += ( + "\n " + "PURE SUBROUTINE libxsmm_smm_" + + mnkstr + + "(a, b, c" + + pfsiga + + " IMPORT :: C_FLOAT\n" + " REAL(C_FLOAT), " + "INTENT(IN) :: a(*), b(*)\n" + " REAL(C_FLOAT), " + "INTENT(INOUT) :: c(*)\n" + + pfsigb + + " END SUBROUTINE" + ) + if 1 != precision: + pfsiga = [ + ") BIND(C)\n", + "," + + "&".rjust(26 - len(mnkstr)) + + "\n & pa, pb, pc) " + + bindc + + "\n" + ][0 != prefetch] + pfsigb = [ + "", + " REAL(C_DOUBLE), " + "INTENT(IN)" + optional + " :: " + "pa(*), " + "pb(*), " + "pc(*)\n" + ][0 != prefetch] + substitute["MNK_INTERFACE_LIST"] += ( + "\n " + "PURE SUBROUTINE libxsmm_dmm_" + + mnkstr + + "(a, b, c" + + pfsiga + + " IMPORT :: C_DOUBLE\n" + " REAL(C_DOUBLE), " + "INTENT(IN) :: a(*), b(*)\n" + " REAL(C_DOUBLE), " + "INTENT(INOUT) :: c(*)\n" + + pfsigb + + " END SUBROUTINE" + ) + substitute["MNK_INTERFACE_LIST"] += "\n END INTERFACE" + print(template.safe_substitute(substitute)) + else: + sys.tracebacklimit = 0 + raise ValueError(sys.argv[0] + ": wrong number of arguments!") diff --git a/third_party/libxsmm/scripts/libxsmm_source.sh b/third_party/libxsmm/scripts/libxsmm_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..863206cfdc01964a6bbbb8fea2ea6b58f9809d2e --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_source.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env sh + +SRCDIR=../src +GREP=$(command -v grep) + +if [ "" = "${GREP}" ]; then + >&2 echo "Error: missing prerequisites!" + exit 1 +fi +cat << EOM +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_SOURCE_H +#define LIBXSMM_SOURCE_H + +#if defined(LIBXSMM_MACROS_H) +# error Please do not include any LIBXSMM header other than libxsmm_source.h! +#endif +#if defined(LIBXSMM_BUILD) +# error LIBXSMM_BUILD cannot be defined for the header-only LIBXSMM! +#endif + +/** + * This header is intentionally called "libxsmm_source.h" since the followings block + * includes *internal* files, and thereby exposes LIBXSMM's implementation. + * The so-called "header-only" usage model gives up the clearly defined binary interface + * (including support for hot-fixes after deployment), and requires to rebuild client + * code for every (internal) change of LIBXSMM. Please make sure to only rely on the + * public interface as the internal implementation may change without notice. + */ +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +EOM + +HERE=$(cd "$(dirname "$0")" && pwd -P) + +if [ "" = "$1" ]; then + DSTDIR=${SRCDIR} +else + DSTDIR=$1 +fi + +# determine order of filenames in directory list +export LC_ALL=C + +# good-enough pattern to match a main function, and to exclude this translation unit +for FILE in $(cd "${HERE}/${SRCDIR}" && ${GREP} -L "main[[:space:]]*(.*)" ./*.c); do + BASENAME=$(basename "${FILE}") + echo "#include \"${DSTDIR}/${BASENAME}\"" +done + +cat << EOM +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#endif /*LIBXSMM_SOURCE_H*/ +EOM + diff --git a/third_party/libxsmm/scripts/libxsmm_specialized.py b/third_party/libxsmm/scripts/libxsmm_specialized.py new file mode 100755 index 0000000000000000000000000000000000000000..fd9b9dd8b52759be8b1311e024478125f671745f --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_specialized.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) Intel Corporation - All rights reserved. # +# This file is part of the LIBXSMM library. # +# # +# For information on the license, see the LICENSE file. # +# Further information: https://github.com/hfp/libxsmm/ # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################### +# Hans Pabst (Intel Corp.) +############################################################################### +import sys + + +if __name__ == "__main__": + argc = len(sys.argv) + if 6 == argc: + precision = int(sys.argv[1]) + m, n, k = int(sys.argv[2]), int(sys.argv[3]), int(sys.argv[4]) + prefetch = int(sys.argv[5]) + + mnkstr = str(m) + "_" + str(n) + "_" + str(k) + optional = ["", ", ..."][0 > prefetch] + signature = ["a, b, c", "a, b, c, pa, pb, pc"][0 < prefetch] + if 2 != precision: + pfsig = [ + optional + ")", + "\n" + ", const float* pa" + ", const float* pb" + ", const float* pc)", + ][0 < prefetch] + print + print + print( + "LIBXSMM_API void libxsmm_smm_" + + mnkstr + + "(const float* a, const float* b, float* c" + + pfsig + ) + print("{") + print( + "#if defined(__AVX512F__) && " + "defined(LIBXSMM_GENTARGET_skx_sp) && \\" + ) + print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))") + print(" libxsmm_smm_" + mnkstr + "_skx(" + signature + ");") + print( + "#elif defined(__AVX512F__) && " + "defined(LIBXSMM_GENTARGET_knl_sp)" + ) + print(" libxsmm_smm_" + mnkstr + "_knl(" + signature + ");") + print( + "#elif defined(__AVX2__) && " + "defined(LIBXSMM_GENTARGET_hsw_sp)" + ) + print(" libxsmm_smm_" + mnkstr + "_hsw(" + signature + ");") + print( + "#elif defined(__AVX__) && " + "defined(LIBXSMM_GENTARGET_snb_sp)" + ) + print(" libxsmm_smm_" + mnkstr + "_snb(" + signature + ");") + print( + "#elif defined(__SSE3__) && " + "defined(LIBXSMM_GENTARGET_wsm_sp)" + ) + print(" libxsmm_smm_" + mnkstr + "_wsm(" + signature + ");") + print("#else") + print( + " const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & " + "LIBXSMM_FLAGS) ? 'N' : 'T');" + ) + print( + " const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & " + "LIBXSMM_FLAGS) ? 'N' : 'T');" + ) + print(" const float alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;") + print( + " const libxsmm_blasint " + "m = " + str(m) + ", " + "n = " + str(n) + ", " + "k = " + str(k) + ";" + ) + if 0 < prefetch: + print( + " LIBXSMM_UNUSED(pa);" + " LIBXSMM_UNUSED(pb);" + " LIBXSMM_UNUSED(pc);" + ) + print( + " LIBXSMM_INLINE_XGEMM(float, float, &transa, &transb," + " &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);" + ) + print("#endif") + print("}") + print + print + print( + "LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_smm_" + + mnkstr + + ")(const float* a, const float* b, float* c" + + pfsig + + ";" + ) + print( + "LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_smm_" + + mnkstr + + ")(const float* a, const float* b, float* c" + + pfsig + ) + print("{") + print(" libxsmm_smm_" + mnkstr + "(" + signature + ");") + print("}") + if 1 != precision: + pfsig = [ + optional + ")", + "\n" + ", const double* pa" + ", const double* pb" + ", const double* pc)", + ][0 < prefetch] + print + print + print( + "LIBXSMM_API void libxsmm_dmm_" + + mnkstr + + "(const double* a, const double* b, double* c" + + pfsig + ) + print("{") + print( + "#if defined(__AVX512F__) && " + "defined(LIBXSMM_GENTARGET_skx_dp) && \\" + ) + print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))") + print(" libxsmm_dmm_" + mnkstr + "_skx(" + signature + ");") + print( + "#elif defined(__AVX512F__) && " + "defined(LIBXSMM_GENTARGET_knl_dp)" + ) + print(" libxsmm_dmm_" + mnkstr + "_knl(" + signature + ");") + print( + "#elif defined(__AVX2__) && " + "defined(LIBXSMM_GENTARGET_hsw_dp)" + ) + print(" libxsmm_dmm_" + mnkstr + "_hsw(" + signature + ");") + print( + "#elif defined(__AVX__) && " + "defined(LIBXSMM_GENTARGET_snb_dp)" + ) + print(" libxsmm_dmm_" + mnkstr + "_snb(" + signature + ");") + print( + "#elif defined(__SSE3__) && " + "defined(LIBXSMM_GENTARGET_wsm_dp)" + ) + print(" libxsmm_dmm_" + mnkstr + "_wsm(" + signature + ");") + print("#else") + print( + " const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & " + "LIBXSMM_FLAGS) ? 'N' : 'T');" + ) + print( + " const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & " + "LIBXSMM_FLAGS) ? 'N' : 'T');" + ) + print(" const double alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;") + print( + " const libxsmm_blasint " + "m = " + str(m) + ", " + "n = " + str(n) + ", " + "k = " + str(k) + ";" + ) + if 0 < prefetch: + print( + " LIBXSMM_UNUSED(pa);" + " LIBXSMM_UNUSED(pb);" + " LIBXSMM_UNUSED(pc);" + ) + print( + " LIBXSMM_INLINE_XGEMM(double, double, &transa, &transb," + " &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);" + ) + print("#endif") + print("}") + print + print + print( + "LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dmm_" + + mnkstr + + ")(const double* a, const double* b, double* c" + + pfsig + + ";" + ) + print( + "LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dmm_" + + mnkstr + + ")(const double* a, const double* b, double* c" + + pfsig + ) + print("{") + print(" libxsmm_dmm_" + mnkstr + "(" + signature + ");") + print("}") + else: + sys.tracebacklimit = 0 + raise ValueError(sys.argv[0] + ": wrong number of arguments!") diff --git a/third_party/libxsmm/scripts/libxsmm_utilities.py b/third_party/libxsmm/scripts/libxsmm_utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..b63372be45b971bb6720f0ac90320732e59a8a56 --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_utilities.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) Intel Corporation - All rights reserved. # +# This file is part of the LIBXSMM library. # +# # +# For information on the license, see the LICENSE file. # +# Further information: https://github.com/hfp/libxsmm/ # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################### +# Hans Pabst (Intel Corp.) +############################################################################### +import itertools +import operator +import inspect +import sys +import os + +try: + from functools import reduce +except ImportError: + pass + + +def upper_list(lists, level): + nlist = len(lists) + upper = [level, level + nlist][1 > level] - 1 + above = lists[upper] + if above: + return above + elif -nlist <= level: + return upper_list(lists, level - 1) + else: + return [] + + +# https://docs.python.org/3/library/itertools.html#itertools.product +def itertools_product(*args): + # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy + # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111 + pools = [tuple(pool) for pool in args] + result = [[]] + for pool in pools: + result = [x + [y] for x in result for y in pool] + for prod in result: + yield tuple(prod) + + +def load_mnklist(argv, threshold, inputformat=0, resultset=None): + if resultset is None: + resultset = set() + if 0 == inputformat: # indexes format + resultset = set(map(lambda mnk: tuple(map(int, mnk.split("_"))), argv)) + elif -1 == inputformat: # new input format + groups = map( + lambda group: [int(i) for i in group.split()], + " ".join(argv[0:]).split(","), + ) + resultset = set( + itertools.chain( + *[list(itertools_product(*(i, i, i))) for i in groups] + ) + ) + elif -2 == inputformat: # legacy format + mlist = list( + map( + int, + map( + lambda s: str(s).replace(",", " ").strip(), + argv[2:2 + int(argv[0])], + ), + ) + ) + nlist = list( + map( + int, + map( + lambda s: str(s).replace(",", " ").strip(), + argv[2 + int(argv[0]):2 + int(argv[0]) + int(argv[1])], + ), + ) + ) + klist = list( + map( + int, + map( + lambda s: str(s).replace(",", " ").strip(), + argv[2 + int(argv[0]) + int(argv[1]):], + ), + ) + ) + mnk = [mlist, nlist, klist] + top = [ + [mlist, upper_list(mnk, 0)][0 == len(mlist)], + [nlist, upper_list(mnk, 1)][0 == len(nlist)], + [klist, upper_list(mnk, 2)][0 == len(klist)], + ] + for m in top[0]: + for n in top[1]: + if not nlist: + n = m + for k in top[2]: + if not klist: + k = n + if not mlist: + m = k + resultset.add((m, n, k)) + else: + sys.tracebacklimit = 0 + raise ValueError("load_mnklist: unexpected input format!") + if 0 != threshold: # threshold requested + return set( + filter( + lambda mnk: (0 < mnk[0]) + and (0 < mnk[1]) + and (0 < mnk[2]) + and (threshold >= (mnk[0] * mnk[1] * mnk[2])), + resultset, + ) + ) + else: + return set( + filter( + lambda mnk: (0 < mnk[0]) and (0 < mnk[1]) and (0 < mnk[2]), + resultset, + ) + ) + + +def max_mnk(mnklist, init=0, index=None): + if index is not None and 0 <= index and index < 3: + mapped = map(lambda mnk: mnk[index], mnklist) + else: + mapped = map(lambda mnk: mnk[0] * mnk[1] * mnk[2], mnklist) + return reduce(max, mapped, init) + + +def median(list_of_numbers, fallback=None, average=True): + size = len(list_of_numbers) + if 0 < size: + # TODO: use nth element + list_of_numbers.sort() + size2 = int(size / 2) + if average and 0 == (size - size2 * 2): + medval = int( + 0.5 * (list_of_numbers[size2 - 1] + list_of_numbers[size2]) + + 0.5 + ) + else: + medval = list_of_numbers[size2] + if fallback is not None: + result = min(medval, fallback) + else: + result = medval + elif fallback is not None: + result = fallback + else: + sys.tracebacklimit = 0 + raise ValueError("median: empty list!") + return result + + +def is_pot(num): + return 0 <= num or 0 == (num & (num - 1)) + + +def sanitize_alignment(alignment): + if 0 >= alignment: + alignment = [1, 64][0 != alignment] + elif not is_pot(alignment): + sys.tracebacklimit = 0 + raise ValueError( + "sanitize_alignment: alignment must be a Power of Two (POT)!" + ) + return alignment + + +def align_value(n, typesize, alignment): + if 0 < typesize and 0 < alignment: + return ( + ((n * typesize + alignment - 1) / alignment) * alignment + ) / typesize + else: + sys.tracebacklimit = 0 + raise ValueError("align_value: invalid input!") + + +def version_branch_from_file(version_filepath): + version_file = open(version_filepath, "r") + version, branch, sep = "1.0", "", "-" + try: + version_list, n = version_file.read().replace("\n", "").split(sep), 0 + for word in version_list: + if not reduce( + operator.and_, + (subword.isdigit() for subword in word.split(".")), + True, + ): + branch += [sep + word, word][0 == n] + n += 1 + else: + break + version = sep.join(version_list[n:]) + finally: + version_file.close() + return (version, branch) + + +def version_numbers(version, branch=None): + version_list = version.split("-") + if not version_list[0][0].isdigit(): + vbranch = version_list[0] + else: + vbranch = "master" + if branch is None or vbranch == branch: + minor = update = patch = 0 + major = 1 + n = len(version_list) + if 1 < n: + patch_list = version_list[n - 1] + if 1 == len(patch_list.split(".")): + version_list = version_list[n - 2].split(".") + if version_list != [vbranch]: + patch = int(patch_list) + else: + major = int(patch_list) + else: + version_list = patch_list.split(".") + else: + version_list = version.split(".") + n = len(version_list) + try: + if 0 < n: + major = int(version_list[0]) + if 1 < n: + minor = int(version_list[1]) + if 2 < n: + update = int(version_list[2]) + except ValueError: + # if 1 == n: major = 0 + pass + else: + major = minor = update = patch = -1 + return major, minor, update, patch + + +def version_branch(max_strlen=-1): + version_filename = "version.txt" + filepath_default = os.path.realpath( + os.path.join( + os.path.dirname(inspect.getfile(inspect.currentframe())), + "..", + version_filename, + ) + ) + filepath_local = os.path.realpath(version_filename) # local version file + realversion, branch = version_branch_from_file(filepath_default) + version = realversion + out_of_tree = filepath_default != filepath_local + if out_of_tree and os.path.isfile(filepath_local): + local, ignored = version_branch_from_file(filepath_local) + if version_numbers(realversion) < version_numbers(local): + version = local + if 0 < max_strlen: + start = int(max_strlen / 3) + cut = max( + branch.rfind("-", start, max_strlen), + branch.rfind("_", start, max_strlen), + branch.rfind(".", start, max_strlen), + ) + if start < cut: + branch = branch[0:cut] + else: + branch = branch[0:max_strlen] + return (version, branch, realversion) + + +if __name__ == "__main__": + argc = len(sys.argv) + if 1 < argc: + arg1 = int(sys.argv[1]) + else: + arg1 = 0 + if -1 == arg1: + if 5 < argc: + # threshold = int(sys.argv[2]) + mnk_size = int(sys.argv[3]) + dims = load_mnklist(sys.argv[4:4 + mnk_size], 0, -1) + dims = load_mnklist(sys.argv[4 + mnk_size:], 0, -2, dims) + mnklist = map(lambda mnk: "_".join(map(str, mnk)), sorted(dims)) + print(" ".join(mnklist)) + elif 3 == argc: + major, minor, update, patch = ( + version_numbers(sys.argv[2], "release") + ) + print(["0", "1"][0 == patch]) + elif 0 <= arg1: + if 0 == arg1 and 3 == argc: + major, minor, update, patch = version_numbers(sys.argv[2]) + print(major) # soname version + else: + version, branch, realversion = version_branch() + major, minor, update, patch = version_numbers(version) + if 1 == arg1: + print(major) + elif 2 == arg1: + print(minor) + elif 3 == arg1: + print(update) + elif 4 == arg1: + print(patch) + elif "" != branch: + print("{}-{}".format(branch, realversion)) + else: + print(realversion) + else: + sys.tracebacklimit = 0 + raise ValueError( + "{}: wrong ({}) number of arguments ('{}') given!".format( + sys.argv[0], argc - 1, " ".join(sys.argv[1:])) + ) diff --git a/third_party/libxsmm/scripts/libxsmm_version.sh b/third_party/libxsmm/scripts/libxsmm_version.sh new file mode 100755 index 0000000000000000000000000000000000000000..670b15d0a5f8f56cc4b3a2640d90c0f359f1c945 --- /dev/null +++ b/third_party/libxsmm/scripts/libxsmm_version.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env sh +############################################################################### +# Copyright (c) Intel Corporation - All rights reserved. # +# This file is part of the LIBXSMM library. # +# # +# For information on the license, see the LICENSE file. # +# Further information: https://github.com/hfp/libxsmm/ # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################### +# Hans Pabst (Intel Corp.) +############################################################################### +GIT=$(command -v git) + +SHIFT=0 +if [ "$1" ]; then + SHIFT=$1 +fi + +NAME=$(${GIT} rev-parse --abbrev-ref HEAD 2>/dev/null) +MAIN=$(${GIT} describe --tags --match "[0-9]*" --abbrev=0 2>/dev/null) + +if [ "${MAIN}" ]; then + VERSION="${NAME}-${MAIN}" + REVC=$(${GIT} rev-list --count --no-merges "${MAIN}"..HEAD 2>/dev/null) +else + VERSION=${NAME} + REVC=$(${GIT} rev-list --count --no-merges HEAD 2>/dev/null) +fi + +echo "${VERSION}-$((REVC+SHIFT))" diff --git a/third_party/libxsmm/src/libxsmm_cpuid_arm.c b/third_party/libxsmm/src/libxsmm_cpuid_arm.c new file mode 100644 index 0000000000000000000000000000000000000000..d3c495927fd4d759a2ffed4293de4d631dd7ba16 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_cpuid_arm.c @@ -0,0 +1,96 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include +#include +#include + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if defined(_MSC_VER) +# define LIBXSMM_CPUID_ARM_ENC16(OP0, OP1, CRN, CRM, OP2) ( \ + (((OP0) & 1) << 14) | \ + (((OP1) & 7) << 11) | \ + (((CRN) & 15) << 7) | \ + (((CRM) & 15) << 3) | \ + (((OP2) & 7) << 0)) +# define ID_AA64ISAR1_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0110, 0b001) +# define ID_AA64PFR0_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0100, 0b000) +# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) RESULT = _ReadStatusReg(ID) +#else +# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) __asm__ __volatile__( \ + "mrs %0," LIBXSMM_STRINGIFY(ID) : "=r"(RESULT)) +#endif + + +#if defined(LIBXSMM_PLATFORM_AARCH64) +LIBXSMM_APIVAR_DEFINE(jmp_buf internal_cpuid_arm_jmp_buf); + +LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int /*signum*/); +LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int signum) +{ + void (*const handler)(int) = signal(signum, internal_cpuid_arm_sigill); + LIBXSMM_ASSERT(SIGILL == signum); + if (SIG_ERR != handler) longjmp(internal_cpuid_arm_jmp_buf, 1); +} +#endif + + +LIBXSMM_API int libxsmm_cpuid_arm(libxsmm_cpuid_info* info) +{ + static int result = LIBXSMM_TARGET_ARCH_UNKNOWN; +#if defined(LIBXSMM_PLATFORM_AARCH64) +#if defined(__APPLE__) && defined(__arm64__) + result = LIBXSMM_AARCH64_V81; +#else + if (LIBXSMM_TARGET_ARCH_UNKNOWN == result) { /* avoid redetecting features */ + void (*const handler)(int) = signal(SIGILL, internal_cpuid_arm_sigill); + result = LIBXSMM_AARCH64_V81; + if (SIG_ERR != handler) { + uint64_t capability; /* 64-bit value */ + if (0 == setjmp(internal_cpuid_arm_jmp_buf)) { + LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64ISAR1_EL1); + if (0xF & capability) { /* DPB */ + result = LIBXSMM_AARCH64_V82; + if (0 == setjmp(internal_cpuid_arm_jmp_buf)) { + LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64PFR0_EL1); + if (0xF & (capability >> 32)) { /* SVE */ + result = LIBXSMM_AARCH64_A64FX; + } + } + } + } + /* restore original state */ + signal(SIGILL, handler); + } + if (NULL != info) LIBXSMM_MEMZERO127(info); + } +#endif +#else +# if !defined(NDEBUG) + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_arm called on non-ARM platform!\n"); + } +# endif + if (NULL != info) LIBXSMM_MEMZERO127(info); +#endif + return result; +} diff --git a/third_party/libxsmm/src/libxsmm_cpuid_x86.c b/third_party/libxsmm/src/libxsmm_cpuid_x86.c new file mode 100644 index 0000000000000000000000000000000000000000..6e90c0cb6ae0a786948b5307449d9ddd93f96ddb --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_cpuid_x86.c @@ -0,0 +1,336 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include +#include +#if !defined(_WIN32) +# include +#endif + +#if defined(LIBXSMM_PLATFORM_X86) +/* XGETBV: receive results (EAX, EDX) for eXtended Control Register (XCR). */ +/* CPUID, receive results (EAX, EBX, ECX, EDX) for requested FUNCTION/SUBFN. */ +#if defined(_MSC_VER) /*defined(_WIN32) && !defined(__GNUC__)*/ +# define LIBXSMM_XGETBV(XCR, EAX, EDX) { \ + unsigned long long libxsmm_xgetbv_ = _xgetbv(XCR); \ + EAX = (int)libxsmm_xgetbv_; \ + EDX = (int)(libxsmm_xgetbv_ >> 32); \ + } +# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) { \ + int libxsmm_cpuid_x86_[/*4*/] = { 0, 0, 0, 0 }; \ + __cpuidex(libxsmm_cpuid_x86_, FUNCTION, SUBFN); \ + EAX = (unsigned int)libxsmm_cpuid_x86_[0]; \ + EBX = (unsigned int)libxsmm_cpuid_x86_[1]; \ + ECX = (unsigned int)libxsmm_cpuid_x86_[2]; \ + EDX = (unsigned int)libxsmm_cpuid_x86_[3]; \ + } +# elif defined(__GNUC__) || !defined(_CRAYC) +# if (64 > (LIBXSMM_BITS)) + LIBXSMM_EXTERN LIBXSMM_RETARGETABLE int __get_cpuid( /* prototype */ + unsigned int, unsigned int*, unsigned int*, unsigned int*, unsigned int*); +# define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0xFFFFFFFF +# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \ + EAX = (EBX) = (EDX) = 0; ECX = (SUBFN); \ + __get_cpuid(FUNCTION, &(EAX), &(EBX), &(ECX), &(EDX)) +# else /* 64-bit */ +# define LIBXSMM_XGETBV(XCR, EAX, EDX) __asm__ __volatile__( \ + ".byte 0x0f, 0x01, 0xd0" /*xgetbv*/ : "=a"(EAX), "=d"(EDX) : "c"(XCR) \ + ) +# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \ + __asm__ __volatile__ (".byte 0x0f, 0xa2" /*cpuid*/ \ + : "=a"(EAX), "=b"(EBX), "=c"(ECX), "=d"(EDX) \ + : "a"(FUNCTION), "b"(0), "c"(SUBFN), "d"(0) \ + ) +# endif +# else /* legacy Cray Compiler */ +# define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0 +# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) EAX = (EBX) = (ECX) = (EDX) = 0 +# endif +#endif + +#define LIBXSMM_CPUID_CHECK(VALUE, CHECK) ((CHECK) == ((CHECK) & (VALUE))) + + +LIBXSMM_API int libxsmm_cpuid_x86(libxsmm_cpuid_info* info) +{ + static int result = LIBXSMM_TARGET_ARCH_UNKNOWN; +#if defined(LIBXSMM_PLATFORM_X86) + unsigned int eax, ebx, ecx, edx; + LIBXSMM_CPUID_X86(0, 0/*ecx*/, eax, ebx, ecx, edx); + if (1 <= eax) { /* CPUID max. leaf */ + /* avoid redetecting features but redetect on request (info given) */ + if (LIBXSMM_TARGET_ARCH_UNKNOWN == result || NULL != info) { + int feature_cpu = LIBXSMM_X86_GENERIC, feature_os = LIBXSMM_X86_GENERIC, has_context = 0; + unsigned int maxleaf = eax; +# if defined(__linux__) + if (0 == libxsmm_se && LIBXSMM_TARGET_ARCH_UNKNOWN == result) { + FILE *const selinux = fopen("/sys/fs/selinux/enforce", "rb"); + if (NULL != selinux) { + if (1 == fread(&libxsmm_se, 1/*sizeof(char)*/, 1/*count*/, selinux)) { + libxsmm_se = ('0' != libxsmm_se ? 1 : 0); + } + else { /* conservative assumption in case of read-error */ + libxsmm_se = 1; + } + fclose(selinux); + } + } +# elif defined(MAP_JIT) + libxsmm_se = 1; +# endif + LIBXSMM_CPUID_X86(1, 0/*ecx*/, eax, ebx, ecx, edx); + if (LIBXSMM_CPUID_CHECK(ecx, 0x00000001)) { /* SSE3(0x00000001) */ + if (LIBXSMM_CPUID_CHECK(ecx, 0x00100000)) { /* SSE42(0x00100000) */ + if (LIBXSMM_CPUID_CHECK(ecx, 0x10000000)) { /* AVX(0x10000000) */ + if (LIBXSMM_CPUID_CHECK(ecx, 0x00001000)) { /* FMA(0x00001000) */ + unsigned int ecx2; + LIBXSMM_CPUID_X86(7, 0/*ecx*/, eax, ebx, ecx2, edx); + /* AVX512F(0x00010000), AVX512CD(0x10000000) */ + if (LIBXSMM_CPUID_CHECK(ebx, 0x10010000)) { /* Common */ + /* AVX512DQ(0x00020000), AVX512BW(0x40000000), AVX512VL(0x80000000) */ + if (LIBXSMM_CPUID_CHECK(ebx, 0xC0020000)) { /* AVX512-Core */ + if (LIBXSMM_CPUID_CHECK(ecx2, 0x00000800)) { /* VNNI */ + unsigned int edx2; /* we need to save edx for AMX check */ +# if 0 /* no check required yet */ + unsigned int ecx3; + LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx3, edx); +# else + LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx2, edx2); +# endif + if (LIBXSMM_CPUID_CHECK(eax, 0x00000020)) { /* BF16 */ + feature_cpu = LIBXSMM_X86_AVX512_CPX; + if (LIBXSMM_CPUID_CHECK(edx, 0x03400000)) { /* AMX-TILE, AMX-INT8, AMX-BF16 */ + feature_cpu = LIBXSMM_X86_AVX512_SPR; + } + } + else feature_cpu = LIBXSMM_X86_AVX512_CLX; /* CLX */ + } + else feature_cpu = LIBXSMM_X86_AVX512_CORE; /* SKX */ + } + /* AVX512PF(0x04000000), AVX512ER(0x08000000) */ + else if (LIBXSMM_CPUID_CHECK(ebx, 0x0C000000)) { /* AVX512-MIC */ + if (LIBXSMM_CPUID_CHECK(edx, 0x0000000C)) { /* KNM */ + feature_cpu = LIBXSMM_X86_AVX512_KNM; + } + else feature_cpu = LIBXSMM_X86_AVX512_MIC; /* KNL */ + } + else feature_cpu = LIBXSMM_X86_AVX512; /* AVX512-Common */ + } + else feature_cpu = LIBXSMM_X86_AVX2; + } + else feature_cpu = LIBXSMM_X86_AVX; + } + else feature_cpu = LIBXSMM_X86_SSE42; + } + else feature_cpu = LIBXSMM_X86_SSE3; + } +# if !defined(LIBXSMM_INTRINSICS_DEBUG) + LIBXSMM_ASSERT_MSG(LIBXSMM_STATIC_TARGET_ARCH <= LIBXSMM_MAX(LIBXSMM_X86_GENERIC, feature_cpu), "missed detecting ISA extensions"); + /* coverity[dead_error_line] */ + if (LIBXSMM_STATIC_TARGET_ARCH > feature_cpu) feature_cpu = LIBXSMM_STATIC_TARGET_ARCH; +# endif + /* XSAVE/XGETBV(0x04000000), OSXSAVE(0x08000000) */ + if (LIBXSMM_CPUID_CHECK(ecx, 0x0C000000)) { /* OS SSE support */ + feature_os = LIBXSMM_MIN(LIBXSMM_X86_SSE42, feature_cpu); + if (LIBXSMM_X86_AVX <= feature_cpu) { + LIBXSMM_XGETBV(0, eax, edx); + if (LIBXSMM_CPUID_CHECK(eax, 0x00000006)) { /* OS XSAVE 256-bit */ + feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX2, feature_cpu); + if (LIBXSMM_CPUID_CHECK(eax, 0x000000E0)) { /* OS XSAVE 512-bit */ + feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX512_CPX, feature_cpu); + if (LIBXSMM_X86_AVX512_SPR <= feature_cpu && 7 <= maxleaf + && LIBXSMM_CPUID_CHECK(eax, 0x00060000)) /* OS XSAVE 512-bit */ + { + feature_os = feature_cpu; /* unlimited AMX */ + } + } + } + } + } + else if (LIBXSMM_X86_GENERIC <= feature_cpu) { + /* assume FXSAVE, which should be fine + * 16 years after the first x86_64 OS + */ + feature_os = LIBXSMM_X86_SSE42; + } + else feature_os = LIBXSMM_TARGET_ARCH_GENERIC; + has_context = (LIBXSMM_STATIC_TARGET_ARCH >= feature_cpu || feature_os >= feature_cpu) ? 1 : 0; + if (LIBXSMM_TARGET_ARCH_UNKNOWN == result && 0 != libxsmm_verbosity) { /* library code is expected to be mute */ +# if !defined(LIBXSMM_TARGET_ARCH) + const int target_vlen32 = libxsmm_cpuid_vlen32(feature_cpu); + const char *const compiler_support = (libxsmm_cpuid_vlen32(LIBXSMM_MAX_STATIC_TARGET_ARCH) < target_vlen32 + ? "" : (((2 <= libxsmm_verbosity || 0 > libxsmm_verbosity) && LIBXSMM_MAX_STATIC_TARGET_ARCH < feature_cpu) + ? "highly " : NULL)); + if (NULL != compiler_support) { + const char *const name = libxsmm_cpuid_name( /* exclude MIC when running on Core processors */ + (((LIBXSMM_X86_AVX512_MIC == LIBXSMM_MAX_STATIC_TARGET_ARCH) || + (LIBXSMM_X86_AVX512_KNM == LIBXSMM_MAX_STATIC_TARGET_ARCH)) && (LIBXSMM_X86_AVX512_CORE <= feature_cpu)) + ? LIBXSMM_X86_AVX2 : LIBXSMM_MAX_STATIC_TARGET_ARCH); + fprintf(stderr, "LIBXSMM WARNING: %soptimized non-JIT code paths are limited to \"%s\"!\n", compiler_support, name); + } +# endif +# if !defined(NDEBUG) && defined(__OPTIMIZE__) + fprintf(stderr, "LIBXSMM WARNING: library is optimized without -DNDEBUG and contains debug code!\n"); +# endif +# if !defined(__APPLE__) || !defined(__MACH__) /* permitted features */ + if (0 == has_context) { + fprintf(stderr, "LIBXSMM WARNING: detected CPU features are not permitted by the OS!\n"); + if (0 == libxsmm_se) { + fprintf(stderr, "LIBXSMM WARNING: downgraded code generation to supported features!\n"); + } + } +# endif + } + /* macOS is faulting AVX-512 (on-demand larger state) */ + result = feature_cpu; +# if !defined(__APPLE__) || !defined(__MACH__) +# if 0 /* opportunistic */ + if (0 == libxsmm_se) +# endif + { /* only permitted features */ + result = LIBXSMM_MIN(feature_cpu, feature_os); + } +# endif + if (NULL != info) { + LIBXSMM_CPUID_X86(0x80000007, 0/*ecx*/, eax, ebx, ecx, edx); + info->constant_tsc = LIBXSMM_CPUID_CHECK(edx, 0x00000100); + info->has_context = has_context; + } + } + } + else { + if (NULL != info) LIBXSMM_MEMZERO127(info); + result = LIBXSMM_X86_GENERIC; + } +#else +# if !defined(NDEBUG) + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_x86 called on non-x86 platform!\n"); + } +# endif + if (NULL != info) LIBXSMM_MEMZERO127(info); +#endif + return result; +} + + +LIBXSMM_API int libxsmm_cpuid(void) +{ +#if defined(LIBXSMM_PLATFORM_X86) + return libxsmm_cpuid_x86(NULL/*info*/); +#else + return libxsmm_cpuid_arm(NULL/*info*/); +#endif +} + + +/** + * This implementation also accounts for non-x86 platforms, + * which not only allows to resolve any given ID but to + * fallback gracefully ("unknown"). + */ +LIBXSMM_API const char* libxsmm_cpuid_name(int id) +{ + const char* target_arch = NULL; + switch (id) { + case LIBXSMM_X86_AVX512_SPR: { + target_arch = "spr"; + } break; + case LIBXSMM_X86_AVX512_CPX: { + target_arch = "cpx"; + } break; + case LIBXSMM_X86_AVX512_CLX: { + target_arch = "clx"; + } break; + case LIBXSMM_X86_AVX512_CORE: { + target_arch = "skx"; + } break; + case LIBXSMM_X86_AVX512_KNM: { + target_arch = "knm"; + } break; + case LIBXSMM_X86_AVX512_MIC: { + target_arch = "knl"; + } break; + case LIBXSMM_X86_AVX512: { + /* TODO: rework BE to use target ID instead of set of strings (target_arch = "avx3") */ + target_arch = "hsw"; + } break; + case LIBXSMM_X86_AVX2: { + target_arch = "hsw"; + } break; + case LIBXSMM_X86_AVX: { + target_arch = "snb"; + } break; + case LIBXSMM_X86_SSE42: { + target_arch = "wsm"; + } break; + case LIBXSMM_X86_SSE3: { + target_arch = "sse3"; + } break; + case LIBXSMM_AARCH64_V81: { + target_arch = "aarch64"; + } break; + case LIBXSMM_AARCH64_A64FX: { + target_arch = "a64fx"; + } break; + case LIBXSMM_TARGET_ARCH_GENERIC: { + target_arch = "generic"; + } break; + default: if (LIBXSMM_X86_GENERIC <= id) { + target_arch = "x86_64"; + } + else { + target_arch = "unknown"; + } + } + LIBXSMM_ASSERT(NULL != target_arch); + return target_arch; +} + + +/** + * This implementation also accounts for non-x86 platforms, + * which not only allows to resolve any given ID but to + * fallback gracefully (scalar). + */ +LIBXSMM_API int libxsmm_cpuid_vlen32(int id) +{ + int result; +#if defined(LIBXSMM_PLATFORM_X86) + if (LIBXSMM_X86_AVX512 <= id) { + result = 16; + } + else if (LIBXSMM_X86_AVX <= id) { + result = 8; + } + else if (LIBXSMM_X86_SSE42 <= id) { + result = 4; + } + else +#elif defined(LIBXSMM_PLATFORM_AARCH64) + if (LIBXSMM_AARCH64_V81 == id) { + result = 4; + } + else if (LIBXSMM_AARCH64_A64FX == id) { + result = 16; + } + else +#else + LIBXSMM_UNUSED(id); +#endif + { /* scalar */ + result = 1; + } + return result; +} diff --git a/third_party/libxsmm/src/libxsmm_diff.h b/third_party/libxsmm/src/libxsmm_diff.h new file mode 100644 index 0000000000000000000000000000000000000000..fed7b82e196f47ef9fd32846d5e1d46236513455 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_diff.h @@ -0,0 +1,144 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DIFF_H +#define LIBXSMM_DIFF_H + +#include + +#if !defined(LIBXSMM_DIFF_AVX512_ENABLED) && 0 +# define LIBXSMM_DIFF_AVX512_ENABLED +#endif + +#define LIBXSMM_DIFF_4_DECL(A) const uint32_t */*const*/ A = NULL +#define LIBXSMM_DIFF_4_ASSIGN(A, B) (A) = (B) +#define LIBXSMM_DIFF_4_LOAD(A, SRC) A = (const uint32_t*)(SRC) +#define LIBXSMM_DIFF_4(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint32_t*)(B))))) + +#define LIBXSMM_DIFF_8_DECL(A) const uint64_t */*const*/ A = NULL +#define LIBXSMM_DIFF_8_ASSIGN(A, B) (A) = (B) +#define LIBXSMM_DIFF_8_LOAD(A, SRC) A = (const uint64_t*)(SRC) +#define LIBXSMM_DIFF_8(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint64_t*)(B))))) + +#define LIBXSMM_DIFF_SSE_DECL(A) __m128i A = LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128() +#define LIBXSMM_DIFF_SSE_ASSIGN(A, B) (A) = (B) +#define LIBXSMM_DIFF_SSE_LOAD(A, SRC) A = LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(SRC)) +#define LIBXSMM_DIFF_SSE(A, B, ...) ((unsigned char)(0xFFFF != _mm_movemask_epi8(_mm_cmpeq_epi8( \ + A, LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(B)))))) + +#if (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH) /*|| defined(LIBXSMM_INTRINSICS_TARGET)*/ +# define LIBXSMM_DIFF_16_DECL LIBXSMM_DIFF_SSE_DECL +# define LIBXSMM_DIFF_16_ASSIGN LIBXSMM_DIFF_SSE_ASSIGN +# define LIBXSMM_DIFF_16_LOAD LIBXSMM_DIFF_SSE_LOAD +# define LIBXSMM_DIFF_16 LIBXSMM_DIFF_SSE +#else +# define LIBXSMM_DIFF_16_DECL(A) const uint64_t */*const*/ A = NULL +# define LIBXSMM_DIFF_16_ASSIGN(A, B) (A) = (B) +# define LIBXSMM_DIFF_16_LOAD(A, SRC) A = (const uint64_t*)(SRC) +# define LIBXSMM_DIFF_16(A, B, ...) ((unsigned char)(0 != (((A)[0] ^ (*(const uint64_t*)(B))) | \ + ((A)[1] ^ ((const uint64_t*)(B))[1])))) +#endif + +#define LIBXSMM_DIFF_AVX2_DECL(A) __m256i A = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256() +#define LIBXSMM_DIFF_AVX2_ASSIGN(A, B) (A) = (B) +#define LIBXSMM_DIFF_AVX2_LOAD(A, SRC) A = _mm256_loadu_si256((const __m256i*)(SRC)) +#define LIBXSMM_DIFF_AVX2(A, B, ...) ((unsigned char)(-1 != _mm256_movemask_epi8(_mm256_cmpeq_epi8( \ + A, _mm256_loadu_si256((const __m256i*)(B)))))) + +#if (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_DIFF_32_DECL LIBXSMM_DIFF_AVX2_DECL +# define LIBXSMM_DIFF_32_ASSIGN LIBXSMM_DIFF_AVX2_ASSIGN +# define LIBXSMM_DIFF_32_LOAD LIBXSMM_DIFF_AVX2_LOAD +# define LIBXSMM_DIFF_32 LIBXSMM_DIFF_AVX2 +#else +# define LIBXSMM_DIFF_32_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_16_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _)) +# define LIBXSMM_DIFF_32_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_16_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_32_, B, _)) +# define LIBXSMM_DIFF_32_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_16_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(SRC) + 2) +# define LIBXSMM_DIFF_32(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_16(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__))) +#endif + +#define LIBXSMM_DIFF_48_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _)) +#define LIBXSMM_DIFF_48_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_48_, B, _)) +#define LIBXSMM_DIFF_48_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(SRC) + 2) +#define LIBXSMM_DIFF_48(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__))) + +#define LIBXSMM_DIFF_64SW_DECL(A) LIBXSMM_DIFF_32_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _)) +#define LIBXSMM_DIFF_64SW_ASSIGN(A, B) LIBXSMM_DIFF_32_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_64_, B, _)) +#define LIBXSMM_DIFF_64SW_LOAD(A, SRC) LIBXSMM_DIFF_32_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(SRC) + 4) +#define LIBXSMM_DIFF_64SW(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_32(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(B) + 4, __VA_ARGS__))) + +#if defined(LIBXSMM_DIFF_AVX512_ENABLED) +# define LIBXSMM_DIFF_AVX512_DECL(A) __m512i A = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() +# define LIBXSMM_DIFF_AVX512_ASSIGN(A, B) (A) = (B) +# define LIBXSMM_DIFF_AVX512_LOAD(A, SRC) A = _mm512_loadu_si512((const __m512i*)(SRC)) +# define LIBXSMM_DIFF_AVX512(A, B, ...) ((unsigned char)(0xFFFF != (unsigned int)/*_cvtmask16_u32*/(_mm512_cmpeq_epi32_mask( \ + A, _mm512_loadu_si512((const __m512i*)(B)))))) +#else +# define LIBXSMM_DIFF_AVX512_DECL LIBXSMM_DIFF_64SW_DECL +# define LIBXSMM_DIFF_AVX512_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN +# define LIBXSMM_DIFF_AVX512_LOAD LIBXSMM_DIFF_64SW_LOAD +# define LIBXSMM_DIFF_AVX512 LIBXSMM_DIFF_64SW +#endif + +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) +# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_AVX512_DECL +# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_AVX512_ASSIGN +# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_AVX512_LOAD +# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_AVX512 +#else +# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_64SW_DECL +# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN +# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_64SW_LOAD +# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_64SW +#endif + +#define LIBXSMM_DIFF_DECL(N, A) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _DECL)(A) +#define LIBXSMM_DIFF_LOAD(N, A, SRC) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _LOAD)(A, SRC) +#define LIBXSMM_DIFF(N) LIBXSMM_CONCATENATE(LIBXSMM_DIFF_, N) + +#define LIBXSMM_DIFF_N(TYPE, RESULT, DIFF, A, BN, ELEMSIZE, STRIDE, HINT, N) { \ + const char* libxsmm_diff_b_ = (const char*)(BN) + (size_t)(HINT) * (STRIDE); \ + for (RESULT = (HINT); (RESULT) < (N); ++(RESULT)) { \ + if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) break; \ + libxsmm_diff_b_ += (STRIDE); \ + } \ + if ((N) == (RESULT)) { /* wrong hint */ \ + TYPE libxsmm_diff_r_ = 0; \ + libxsmm_diff_b_ = (const char*)(BN); /* reset */ \ + for (; libxsmm_diff_r_ < (HINT); ++libxsmm_diff_r_) { \ + if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) { \ + RESULT = libxsmm_diff_r_; \ + break; \ + } \ + libxsmm_diff_b_ += (STRIDE); \ + } \ + } \ +} + + +/** Function type representing the diff-functionality. */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE unsigned int (*libxsmm_diff_function)( + const void* /*a*/, const void* /*b*/, ... /*size*/); + +/** Compare two data blocks of 4 Byte each. */ +LIBXSMM_API unsigned char libxsmm_diff_4(const void* a, const void* b, ...); +/** Compare two data blocks of 8 Byte each. */ +LIBXSMM_API unsigned char libxsmm_diff_8(const void* a, const void* b, ...); +/** Compare two data blocks of 16 Byte each. */ +LIBXSMM_API unsigned char libxsmm_diff_16(const void* a, const void* b, ...); +/** Compare two data blocks of 32 Byte each. */ +LIBXSMM_API unsigned char libxsmm_diff_32(const void* a, const void* b, ...); +/** Compare two data blocks of 48 Byte each. */ +LIBXSMM_API unsigned char libxsmm_diff_48(const void* a, const void* b, ...); +/** Compare two data blocks of 64 Byte each. */ +LIBXSMM_API unsigned char libxsmm_diff_64(const void* a, const void* b, ...); + +#endif /*LIBXSMM_DIFF_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_dnn.c b/third_party/libxsmm/src/libxsmm_dnn.c new file mode 100644 index 0000000000000000000000000000000000000000..4627a34c8e228e05583645571d0b7c73106206a9 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn.c @@ -0,0 +1,759 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(_OPENMP) +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + + +LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch) +{ + LIBXSMM_UNUSED(target_arch); +} + + +LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void) +{ +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks( int C, int K, int* C_block, int* K_block, int* fm_lp_block, libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out ) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + int ifmblock = 0; + int ofmblock = 0; + int lp_block = 0; + int tmp_max_c_block = 64; + int tmp_max_k_block = 64; + int tmp_block = 0; + + /* init libxsmm */ + LIBXSMM_INIT + + /* C */ + if ( ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (datatype_in == LIBXSMM_DNN_DATATYPE_BF16)) || + (libxsmm_target_archid < LIBXSMM_X86_AVX512 ) ) { + tmp_max_c_block = 32; + } else if ( libxsmm_target_archid == LIBXSMM_AARCH64_V81 ) { + tmp_max_c_block = 16; + } + if ( C < tmp_max_c_block ) { + ifmblock = C; + } else { + for ( tmp_block = 1; tmp_block <= tmp_max_c_block; tmp_block *= 2 ) { + if ( C % tmp_block == 0 ) ifmblock = tmp_block; + } + } + + /* K */ + if ( ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (datatype_in == LIBXSMM_DNN_DATATYPE_BF16)) || + (libxsmm_target_archid < LIBXSMM_X86_AVX512 ) ) { + tmp_max_k_block = 32; + } else if ( libxsmm_target_archid == LIBXSMM_AARCH64_V81 ) { + tmp_max_k_block = 16; + } + if ( K < tmp_max_k_block ) { + ofmblock = K; + } else { + for ( tmp_block = 1; tmp_block <= tmp_max_k_block; tmp_block *= 2 ) { + if ( K % tmp_block == 0 ) ofmblock = tmp_block; + } + } + + /* when do we need VNNI format? */ + if ( (datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + lp_block = 1; + } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + lp_block = 2; + } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_I16) && ((datatype_out == LIBXSMM_DNN_DATATYPE_I32) || (datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) { + lp_block = 2; + } else if (datatype_in == LIBXSMM_DNN_DATATYPE_I8) { + lp_block = 4; + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + + *C_block = ifmblock; + *K_block = ofmblock; + *fm_lp_block = lp_block; + + return status; +} + + +LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code) +{ + switch (code) { + case LIBXSMM_DNN_SUCCESS: + return "LIBXSMM DNN Success!"; + case LIBXSMM_DNN_WARN_FALLBACK: + return "LIBXSMM DNN Warning: Falling back to naive code as target is currently not supported by LIBXSMM!"; + case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING: + return "LIBXSMM DNN Warning: RNN cell suboptimal minibatch blocking!"; + case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING: + return "LIBXSMM DNN Warning: RNN cell suboptimal input feature blocking!"; + case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING: + return "LIBXSMM DNN Warning: RNN cell suboptimal output feature blocking!"; + case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING: + return "LIBXSMM DNN Warning: FC layer suboptimal minibatch blocking!"; + case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING: + return "LIBXSMM DNN Warning: FC layer suboptimal input feature blocking!"; + case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING: + return "LIBXSMM DNN Warning: FC layer suboptimal output feature blocking!"; + case LIBXSMM_DNN_ERR_GENERAL: + return "LIBXSMM DNN Error: General error occurred!"; + case LIBXSMM_DNN_ERR_CREATE_HANDLE: + return "LIBXSMM DNN Error: Handle creation failed!"; + case LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE: + return "LIBXSMM DNN Error: Requested datatype is not available!"; + case LIBXSMM_DNN_ERR_INVALID_BLOCKING: + return "LIBXSMM DNN Error: Requested Input/Output buffer size cannot be blocked!"; + case LIBXSMM_DNN_ERR_INVALID_HANDLE: + return "LIBXSMM DNN Error: An invalid handle was provided!"; + case LIBXSMM_DNN_ERR_DATA_NOT_BOUND: + return "LIBXSMM DNN Error: Not all required sources and destinations have been bound to convolution!"; + case LIBXSMM_DNN_ERR_CREATE_TENSOR: + return "LIBXSMM DNN Error: Tensor creation failed!"; + case LIBXSMM_DNN_ERR_INVALID_TENSOR: + return "LIBXSMM DNN Error: Invalid tensor was specified!"; + case LIBXSMM_DNN_ERR_MISMATCH_TENSOR: + return "LIBXSMM DNN Error: Tensor doesn't match handle it should be bind to!"; + case LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR: + return "LIBXSMM DNN Error: Invalid handle or tensor!"; + case LIBXSMM_DNN_ERR_INVALID_KIND: + return "LIBXSMM DNN Error: Invalid convolution kind!"; + case LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW: + return "LIBXSMM DNN Error: NCHW format is currently not natively supported by LIBXSMM!"; + case LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT: + return "LIBXSMM DNN Error: Unsupported destination format when copying data!"; + case LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT: + return "LIBXSMM DNN Error: Unsupported source format when copying data!"; + case LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE: + return "LIBXSMM DNN Error: Unsupported format when requesting a convolution!"; + case LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS: + return "LIBXSMM DNN Error: KCRS format is currently not natively supported by LIBXSMM!"; + case LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL: + return "LIBXSMM DNN Error: Invalid format was specified!"; + case LIBXSMM_DNN_ERR_CREATE_LAYOUT: + return "LIBXSMM DNN Error: Layout creation failed!"; + case LIBXSMM_DNN_ERR_INVALID_LAYOUT: + return "LIBXSMM DNN Error: Invalid layout was specified!"; + case LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH: + return "LIBXSMM DNN Error: Unsupported architecture!"; + case LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED: + return "LIBXSMM DNN Error: scratch binding failed as scratch was not allocated!"; + case LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE: + return "LIBXSMM DNN Error: an unknown tensor type was provided!"; + case LIBXSMM_DNN_ERR_INVALID_ALGO: + return "LIBXSMM DNN Error: Invalid algorithm was specified!"; + case LIBXSMM_DNN_ERR_INVALID_PADDING: + return "LIBXSMM DNN Error: Invalid padding was specified!"; + case LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL: + return "LIBXSMM DNN Error: time steps should be >= 2 for RNN/LSTM!"; + case LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS: + return "LIBXSMM DNN Error: failed to create internal layout arrays!"; + case LIBXSMM_DNN_ERR_NOT_IMPLEMENTED: + return "LIBXSMM DNN Error: the requested functionality is right now not implemented!"; + case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER: + return "LIBXSMM DNN Error: the requested order of fusion in batch norm is right now not implemented!"; + case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION: + return "LIBXSMM DNN Error: the requested fusion in batch norm is right now not implemented!"; + case LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN: + return "LIBXSMM DNN Error: Unsupported format when requesting a fused batch norm!"; + case LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING: + return "LIBXSMM DNN Error: Unsupported pooling operations was requested!"; + case LIBXSMM_DNN_ERR_INVALID_FORMAT_FC: + return "LIBXSMM DNN Error: Unsupported format when requesting a fullyconnected layer!"; + case LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN: + return "LIBXSMM DNN Error: max sequence length is shorter than sequence length we attempt to set!"; + case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER: + return "LIBXSMM DNN Error: the requested order of fusion in group norm is right now not implemented!"; + case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION: + return "LIBXSMM DNN Error: the requested fusion in group norm is right now not implemented!"; + case LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION: + return "LIBXSMM DNN Error: the requested fusion in fullyconnected is right now not implemented!"; + default: + return "LIBXSMM DNN Error: Unknown error or warning occurred!"; + } +} + + +LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype) +{ + switch (datatype) { + case LIBXSMM_DNN_DATATYPE_F32: return 4; + case LIBXSMM_DNN_DATATYPE_I32: return 4; + case LIBXSMM_DNN_DATATYPE_BF16: return 2; + case LIBXSMM_DNN_DATATYPE_I16: return 2; + case LIBXSMM_DNN_DATATYPE_I8: return 1; + /* no error expected as enumeration really arrives at an enum; compiler-checked */ + default: return 1; + } +} + + +LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype) +{ + size_t l_cl_width_bytes; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( libxsmm_target_archid == LIBXSMM_X86_GENERIC || + libxsmm_target_archid == LIBXSMM_X86_SSE3 || + libxsmm_target_archid == LIBXSMM_X86_SSE42 ) { + l_cl_width_bytes = 16; + } else if ( libxsmm_target_archid == LIBXSMM_X86_AVX2 || + libxsmm_target_archid == LIBXSMM_X86_AVX ) { + l_cl_width_bytes = 32; + } else { + l_cl_width_bytes = 64; + } + + return l_cl_width_bytes/libxsmm_dnn_typesize(datatype); +} + +LIBXSMM_API_INLINE float libxsmm_internal_get_max( float* in_buffer, int length ) { + float absmax_value = LIBXSMM_ABS(in_buffer[0]); + int i = 0; +#ifdef _OPENMP + LIBXSMM_OMP_VAR(i); +# pragma omp parallel private(i) + { + float my_absmax_value = absmax_value; +# pragma omp for + for (i = 0; i < length; ++i ) { + if (LIBXSMM_ABS(in_buffer[i]) > my_absmax_value) { + my_absmax_value = LIBXSMM_ABS(in_buffer[i]); + } + } +# pragma omp critical + { + if (my_absmax_value > absmax_value) { + absmax_value = my_absmax_value; + } + } + } +#else + for (i = 1; i < length; ++i ) { + if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) { + absmax_value = LIBXSMM_ABS(in_buffer[i]); + } + } +#endif + + return absmax_value; +} + + +LIBXSMM_API_INLINE unsigned char libxsmm_internal_get_max_exp( float* in_buffer, int length ) { + libxsmm_intfloat val_exp; + unsigned char max_exp = 0; + + /* bit-wise conversion to int */ + val_exp.f = libxsmm_internal_get_max( in_buffer, length ); + /* shift by mantissa to the right and convert to char */ + max_exp = (unsigned char)((val_exp.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32); + + return max_exp; +} + + +LIBXSMM_API_INLINE short libxsmm_internal_quantize_scalar_no_scf( float input, unsigned char max_exp, unsigned char add_shift, int round_mode ) { + libxsmm_intfloat value; + unsigned int qvalue = 0; + unsigned int mant = 0; + unsigned int sign = 0; + unsigned char rhs = 0; + unsigned char exp_off = 0; + + /* init libxsmm */ + LIBXSMM_INIT + + /* in case of zero we don't need to do anything */ + if (LIBXSMM_FEQ(input, 0)) { + qvalue = 0; + } else { + /* let's get a float copy to work on */ + /* vinp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( in_buffer[i] ); */ + value.f = input; + /* let's compute the offset of the current exp at pos i from max offset, we need to mask the sign bit though */ + /*__m512i vexp = _mm512_cvtps_epi32(_mm512_getexp_ps (vinp)); + __m512i vexp_off = _mm512_sub_epi32(maxexpf, vexp);*/ + exp_off = (unsigned char)(max_exp - ((value.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32)); + /* cut out mantissa and set leading bit */ + /*__m512i mmask = _mm512_set1_epi32(LIBXSMM_DNN_MASK_MANT_F32); + __m512i vmant = _mm512_or_epi32(_mm512_set1_epi32(0x1 << LIBXSMM_DNN_MANT_SZ_F32), _mm512_and_epi32( _mm512_castps_si512( vinp ), mmask));*/ + mant = ((0x1 << LIBXSMM_DNN_MANT_SZ_F32) | (value.ui & LIBXSMM_DNN_MASK_MANT_F32)); + /* extract sign */ + /* __mmask16 smask = _mm512_cmplt_ps_mask (inp, _mm512_set1_ps(0)); */ + sign = ((value.ui & LIBXSNN_DNN_MASK_SIGN_F32) >> (LIBXSMM_DNN_SZ_F32-1)); + /* calculate rhs, be aware of the now explicit leading bit, @TODO add DFP8/4 */ + rhs = (unsigned char)((LIBXSMM_DNN_MANT_SZ_F32+1) - LIBXSMM_DNN_MANT_DFP16 + exp_off + add_shift); + /* some safety, to generate 0 when we fall off quant region, @TODO issue a LIBXSMM WARNING: that we shifted out the entire mantissa */ + if (rhs > (LIBXSMM_DNN_MANT_SZ_F32+1)) { + rhs = (LIBXSMM_DNN_MANT_SZ_F32+1); + } + /* finally shift the value into the region we need, this is now a 15-add_rhs bit number for the max value in in_buffer */ + qvalue = (mant >> rhs); + /* handle sign, 2 complement */ + if ( (sign > 0) && (qvalue > 0) ) { + qvalue = (~qvalue + 1); + } + + if (round_mode == LIBXSMM_DNN_QUANT_BIAS_ROUND) { + /* biased rounding towards next bigger number */ + /* first let's determine in the original number if we need a bias rounding, @TODO need fix for F64 */ + int bias_needed = (mant & (0x3 << (rhs-2))); + /* apply bias */ + if (bias_needed > 0) { + qvalue++; + } + } else if (round_mode == LIBXSMM_DNN_QUANT_NEAREST_ROUND) { + int nearest_needed = (mant & (0x1 << (rhs-1))); + /* apply rounding */ + if ((nearest_needed > 0) && (rhs > 1)) { + qvalue++; + } + } else if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) { + /* stochastic rounding, as implemented in the IBM paper from 2015, @TODO, fix F64 and DFP8 */ + const float eps = LIXSMMM_DNN_RES_DFP16; + /* coverity[dont_call] */ + const float r = (float)rand(); + libxsmm_intfloat fvalue; + float p, q; + /* masking all bits which will be shifted out */ + fvalue.ui = value.ui & ((LIBXSMM_DNN_MASK_FULL_F32) << rhs); + /* drawing a random number */ + p = r/((float)RAND_MAX); + q = (input - fvalue.f)/eps; + /* apply rounding if needed */ + if ((p + q) > 0.5f) { + ++qvalue; + } + } else { + /* do nothing about rounding, just chop */ + } + } + + return (short)qvalue; +} + + +/* @TODO make this routine aware of any int type */ +LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ) { + int i = 0; + + /* init libxsmm */ + LIBXSMM_INIT + + /* in case we are using FP-Mul based quantization we use a different path for now + @TODO let's unify the paths by using the similar vectorization for both */ + if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) { + const float max_value = libxsmm_internal_get_max( in_buffer, length ); + int maxexp = 0; + /* take return value of LIBXSMM_FREXPF to mute static analysis issue */ + float scfq = LIBXSMM_FREXPF(max_value, &maxexp); + maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift); + scfq = libxsmm_sexp2_i8i(-maxexp); + +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + if ( length % 16 == 0 ) { + __m512 vscfq = _mm512_set1_ps(scfq); +#ifdef _OPENMP +# pragma omp parallel for private(i) +#endif + for (i = 0; i < length; i+=16 ) { + _mm256_stream_si256( (__m256i *)&(out_buffer[i]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i]), vscfq ) ); + } + } else { +#endif +#ifdef _OPENMP +# pragma omp parallel for private(i) +#endif + for (i = 0; i < length; ++i ) { + out_buffer[i] = (short)LIBXSMM_ROUNDF(in_buffer[i] * scfq); + } +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + } +#endif + /* @TODO, we need to potentially fix this unsigned char problem */ +#if !defined(NDEBUG) /* library code is expected to be mute */ + if (maxexp > 0) { + fprintf(stderr, "error quant fil\n"); + } +#endif + *scf = (unsigned char)(-maxexp); + } else { + /* get max exponent */ + unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, length ); + + /* if we go for stochastic rounding, let's initialize random seed */ + if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) { + srand(libxsmm_timer_tick() % ((unsigned int)-1)); + } + +#ifdef _OPENMP +# pragma omp parallel for private(i) +#endif + for (i = 0; i < length; ++i ) { + out_buffer[i] = libxsmm_internal_quantize_scalar_no_scf( in_buffer[i], max_exp, add_shift, round_mode ); + } + + *scf = (unsigned char)(14 - add_shift - (max_exp - 127)); + } +} + + +LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) { + LIBXSMM_VLA_DECL(5, const float, in, in_buffer, C/cblk_f32, H, W, cblk_f32); + LIBXSMM_VLA_DECL(6, short, out, out_buffer, C/(cblk_i16*lp_blk), H, W, cblk_i16, lp_blk); + const unsigned int cblk = C/(cblk_i16*lp_blk); + int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6; + + /* init libxsmm */ + LIBXSMM_INIT + + /* some quick and dirty checks */ + assert((C % cblk_f32) == 0); + assert((C % cblk_i16) == 0); + + /* in case we are using FP-Mul based quantization we use a different path for now + @TODO let's unify the paths by using the similar vectorization for both */ + if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) { + const float max_value = libxsmm_internal_get_max( in_buffer, N*C*H*W ); + int maxexp = 0; + /* take return value of LIBXSMM_FREXPF to mute static analysis issue */ + float scfq = LIBXSMM_FREXPF(max_value, &maxexp); + maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift); + scfq = libxsmm_sexp2_i8i(-maxexp); + +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + if ( (cblk_f32 == 16) && (cblk_i16*lp_blk == 16) ) { + __m512 vscfq = _mm512_set1_ps(scfq); +#ifdef _OPENMP + LIBXSMM_OMP_VAR(i1); +# pragma omp parallel for private(i1) +#endif + for (i1 = 0; i1 < (int)(N*C*H*W); i1 += 16 ) { + _mm256_stream_si256( (__m256i *)&(out_buffer[i1]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i1]), vscfq ) ); + } + } else { +#endif +#ifdef _OPENMP + LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); +# pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4) +#endif + for (i1 = 0; i1 < (int)N; ++i1 ) { + for (i2 = 0; i2 < (int)cblk; ++i2 ) { + for (i3 = 0; i3 < (int)H; ++i3 ) { + for (i4 = 0; i4 < (int)W; ++i4 ) { + for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) { + for (i6 = 0; i6 < (int)lp_blk; ++i6 ) { + const int fi1 = i1; + const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32; + const int fi3 = i3; + const int fi4 = i4; + const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32; + LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF( + LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32) * scfq); + } + } + } + } + } + } +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + } +#endif + /* @TODO, we need to potentially fix this unsigned char problem */ +#if !defined(NDEBUG) /* library code is expected to be mute */ + if (maxexp > 0) { + fprintf(stderr, "error quant act\n"); + } +#endif + *scf = (unsigned char)(-maxexp); + } else { + /* get max exponent */ + unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, N*C*H*W ); + + /* if we go for stochastic rounding, let's initialize random seed */ + if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) { + srand(libxsmm_timer_tick() % ((unsigned int)-1)); + } + +#ifdef _OPENMP +# pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4) +#endif + for (i1 = 0; i1 < (int)N; ++i1 ) { + for (i2 = 0; i2 < (int)cblk; ++i2 ) { + for (i3 = 0; i3 < (int)H; ++i3 ) { + for (i4 = 0; i4 < (int)W; ++i4 ) { + for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) { + for (i6 = 0; i6 < (int)lp_blk; ++i6 ) { + const int fi1 = i1; + const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32; + const int fi3 = i3; + const int fi4 = i4; + const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32; + LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf( + LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32), max_exp, add_shift, round_mode); + } + } + } + } + } + } + + *scf = (unsigned char)(14 - add_shift - (max_exp - 127)); + } +} + + +LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) { + LIBXSMM_VLA_DECL(6, const float, in, in_buffer, C/cblk_f32, R, S, cblk_f32, kblk_f32); + LIBXSMM_VLA_DECL(7, short, out, out_buffer, C/(cblk_i16*lp_blk), R, S, cblk_i16, kblk_i16, lp_blk); + unsigned int cblk = C/(cblk_i16*lp_blk); + unsigned int kblk = K/kblk_i16; + int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6, i7; + + /* some quick and dirty checks */ + assert((C % cblk_f32) == 0); + assert((C % (cblk_i16*lp_blk)) == 0); + assert((K % kblk_f32) == 0); + assert((K % kblk_i16) == 0); + assert((lp_blk % 2) == 0); + + /* init libxsmm */ + LIBXSMM_INIT + + /* in case we are using FP-Mul based quantization we use a different path for now + @TODO let's unify the paths by using the similar vectorization for both */ + if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) { + const float max_value = libxsmm_internal_get_max( in_buffer, K*C*R*S ); + int maxexp = 0; + /* take return value of LIBXSMM_FREXPF to mute static analysis issue */ + float scfq = LIBXSMM_FREXPF(max_value, &maxexp); + maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift); + scfq = libxsmm_sexp2_i8i(-maxexp); + +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + if ( (kblk_f32 == 16) && (cblk_f32 == 16) && (kblk_i16 == 16) && (cblk_i16*lp_blk == 16) ) { + const __m512 vscfq = _mm512_set1_ps(scfq); + const __m512i permute_compact_idx = _mm512_set_epi32(15,14,13,12,7,6,5,4,11,10,9,8,3,2,1,0); +#ifdef _OPENMP +# pragma omp parallel for private(i1, i2, i3, i4, i5) LIBXSMM_OPENMP_COLLAPSE(4) +#endif + for (i1 = 0; i1 < (int)kblk; ++i1 ) { + for (i2 = 0; i2 < (int)cblk; ++i2 ) { + for (i3 = 0; i3 < (int)R; ++i3 ) { + for (i4 = 0; i4 < (int)S; ++i4 ) { + for (i5 = 0; i5 < 16; i5+=2 ) { + __m256i even_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( + &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 0, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq); + __m256i odd_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( + &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 1, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq); + __m256i compressed_lo = _mm256_unpacklo_epi16(even_ch, odd_ch); + __m256i compressed_hi = _mm256_unpackhi_epi16(even_ch, odd_ch); + __m512i compact = _mm512_inserti64x4( _mm512_setzero_si512(), compressed_lo, 0); + compact = _mm512_inserti64x4(compact, compressed_hi, 1); + compact = _mm512_permutexvar_epi32(permute_compact_idx, compact); + LIBXSMM_INTRINSICS_MM512_STREAM_SI512( + (void*)&LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5 / 2, 0, 0, cblk, R, S, cblk_i16, kblk_i16, lp_blk), + compact); + } + } + } + } + } + } else { +#endif +#ifdef _OPENMP + LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); LIBXSMM_OMP_VAR(i7); +# pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4) +#endif + for (i1 = 0; i1 < (int)kblk; ++i1 ) { + for (i2 = 0; i2 < (int)cblk; ++i2 ) { + for (i3 = 0; i3 < (int)R; ++i3 ) { + for (i4 = 0; i4 < (int)S; ++i4 ) { + for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) { + for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) { + for (i7 = 0; i7 < (int)lp_blk; ++i7 ) { + const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32; + const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32; + const int fi3 = i3; + const int fi4 = i4; + const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32; + const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32; + LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF( + LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32) * scfq); + } + } + } + } + } + } + } +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + } +#endif + /* @TODO, we need to potentially fix this unsigned char problem */ +#if !defined(NDEBUG) /* library code is expected to be mute */ + if (maxexp > 0) { + fprintf(stderr, "error quant fil\n"); + } +#endif + *scf = (unsigned char)(-maxexp); + } else { + /* get max exponent */ + unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, K*C*R*S ); + + /* if we go for stochastic rounding, let's initialize random seed */ + if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) { + srand(libxsmm_timer_tick() % ((unsigned int)-1)); + } + +#ifdef _OPENMP +# pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4) +#endif + for (i1 = 0; i1 < (int)kblk; ++i1 ) { + for (i2 = 0; i2 < (int)cblk; ++i2 ) { + for (i3 = 0; i3 < (int)R; ++i3 ) { + for (i4 = 0; i4 < (int)S; ++i4 ) { + for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) { + for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) { + for (i7 = 0; i7 < (int)lp_blk; ++i7 ) { + const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32; + const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32; + const int fi3 = i3; + const int fi4 = i4; + const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32; + const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32; + LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf( + LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32), max_exp, add_shift, round_mode); + } + } + } + } + } + } + } + + *scf = (unsigned char)(14 - add_shift - (max_exp - 127)); + } +} + + +LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ) { + const float val_exp = libxsmm_sexp2_i8i(-scf); + int i = 0; + +#ifdef _OPENMP +# pragma omp parallel for private(i) +#endif + for ( i = 0; i < length; ++i ) { + out_buffer[i] = ((float)in_buffer[i])*val_exp; + } +} + + +LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length) { + unsigned int i = 0; + + /* truncate buffer to bf16 */ + for ( i = 0; i < length; ++i ) { + libxsmm_bfloat16_hp t; + + t.f = in[i]; + out[i] = t.i[1]; + } +} + + +LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) { + unsigned int i = 0; + + /* truncate buffer to bf16 */ + for ( i = 0; i < len; ++i ) { + unsigned int int_round = 0; + unsigned int do_round = 1; + + int_round = *((unsigned int*)&(in[i])); + + /* we don't round NaN and inf */ + if ( (int_round & 0x7f800000) == 0x7f800000 ) { + do_round = 0; + } + + /* perform round nearest tie away from zero */ + if ( do_round != 0 ) { + int_round = int_round + 0x00008000; + } + + /* create the bf16 value by shifting out the lower 16bits */ + int_round = int_round >> 16; + + out[i] = (libxsmm_bfloat16)int_round; + } +} + + +LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) { + unsigned int i = 0; + + /* truncate buffer to bf16 */ + for ( i = 0; i < len; ++i ) { + unsigned int int_round = 0; + unsigned int do_round = 1; + + int_round = *((unsigned int*)&(in[i])); + + /* we don't round NaN and inf */ + if ( (int_round & 0x7f800000) == 0x7f800000 ) { + do_round = 0; + } + + /* perform round nearest tie even */ + if ( do_round != 0 ) { + unsigned int fixup = (int_round >> 16) & 1; + int_round = int_round + 0x00007fff + fixup; + } + + /* create the bf16 value by shifting out the lower 16bits */ + int_round = int_round >> 16; + + out[i] = (unsigned short)int_round; + } +} + + +LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length) { + unsigned int i = 0; + + /* up-convert is super simple */ + for ( i = 0; i < length; ++i ) { + libxsmm_bfloat16_hp t; + + t.i[1] = in[i]; + t.i[0] = 0; + out[i] = t.f; + } +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution.c b/third_party/libxsmm/src/libxsmm_dnn_convolution.c new file mode 100644 index 0000000000000000000000000000000000000000..2ba07679553c270c32021eee9bad545c4c844bb1 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution.c @@ -0,0 +1,2747 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst, Alexander Heinecke, Evangelos Georganas, Rajkishore Barik (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" +#include "libxsmm_dnn_convolution_forward.h" +#include "libxsmm_dnn_convolution_backward.h" +#include "libxsmm_dnn_convolution_weight_update.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(_OPENMP) +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#define MIXED 0 +#define KHWC 1 +#define HWKC 2 +#define CHWK 3 +#define HWCK 4 + +/**********************************************************/ +/* Helper functions for convolutions' general param setup */ +/**********************************************************/ +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ifmblock( libxsmm_dnn_layer* handle ) { + int result = 1; + int ofm, lp; + + libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, &result, &ofm, &lp, handle->desc.datatype_in, handle->desc.datatype_out ); + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ofmblock( libxsmm_dnn_layer* handle ) { + int result = 1; + int ifm, lp; + + libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, &ifm, &result, &lp, handle->desc.datatype_in, handle->desc.datatype_out ); + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fm_lp_block( libxsmm_dnn_layer* handle ) { + int result = 1; + int ifm, ofm; + + libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, &ifm, &ofm, &result, handle->desc.datatype_in, handle->desc.datatype_out ); + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_fwd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* FIXME: For now fallback only if MB is not divisible by number of threads */ + if (handle->desc.N % handle->desc.threads != 0) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm( libxsmm_dnn_layer* handle ) { + int result = handle->desc.C / handle->ifmblock; + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm( libxsmm_dnn_layer* handle ) { + int result = handle->desc.K / handle->ofmblock; + return result; +} + +/**********************************************************/ +/* Helper functions for FWD convolutions' parameter setup */ +/**********************************************************/ +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofw_rb( libxsmm_dnn_layer* handle ) { + int result = 0; + result = handle->ofw; + if (handle->ofw == 56) { + result = 28; + } + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) { + if (handle->ofw % 2 == 0) { + result = handle->ofw/2; + } + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_fwd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Pack only for small images and when having large K to amortize, and we can only pack for 1x1 convolutions */ + if ((handle->ofw <= 14) && (handle->desc.K > 512) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u == 2) && (handle->desc.v == 2)) { + result = 1; + } + + /* For SPR we allow packing more aggressively to generate more efficient BRGEMMs */ + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + if ((handle->ofw <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u == 2) && (handle->desc.v == 2)) { + result = 1; + } + } + + /* Make sure we don't pack when minibatch is not divisible by number of threads since H is used potentially for parallelism */ + if (handle->desc.N != handle->desc.threads) { + result = 0; + } + /* we don't pack for int8 */ + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) { + result = 0; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofh_rb( libxsmm_dnn_layer* handle ) { + int result = 1; + /* Multiple rows for "small" images and 1x1 convolutions */ + if ((handle->ofh <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1)) { + result = handle->ofh; + } + + /* In this case we will be using fallback generic loops, thus ofh_rb should be 1 */ + if ((handle->desc.N % handle->desc.threads != 0) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) { + result = 1; + } + + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + if (handle->ofw == 7 && handle->ofh == 7 && handle->desc.R == 3 && handle->desc.S == 3) { + result = 7; + } + if (handle->ofw == 14 && handle->ofh == 14 /*&& handle->desc.R == 3 && handle->desc.S == 3*/) { + result = 2; + } + } + + /* Make sure we don't use multiple rows when we don't pack input and convolutions are strided*/ + if ((handle->pack_input == 0) && ((handle->desc.u !=1 ) || (handle->desc.v != 1))) { + result = 1; + } + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_pixels_gemm( libxsmm_dnn_layer* handle ) { + int result = handle->fwd_ofw_rb * handle->fwd_ofh_rb; + /* In the case below we calculate redundantly pixels in order to efficiently use AMX */ + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + if (handle->desc.R != 1 || handle->desc.R != 1) { + if (handle->ofw < 24) { + result = (handle->fwd_ofw_rb+2*handle->desc.pad_w) * (handle->fwd_ofh_rb-2) + 2 * (handle->fwd_ofw_rb+handle->desc.pad_w); + } + } + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_block_H( libxsmm_dnn_layer* handle ) { + int result = 14; + + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + /* Spatial dimension block tuning for SPR */ + if ((handle->ofh == 7 && handle->desc.u == 2) || (handle->ofh == 14 && handle->desc.R != 3 ) || handle->ofh == 27 || (handle->ofh == 28 && handle->desc.R == 1) || handle->ofh == 48 || handle->ofh == 54 || handle->ofh == 56 || handle->ofh == 112 ) { + result = 4; + } + } else { + /* Block H only for large images */ + if (handle->ofh >= 28) { + result = 4; + } + if (handle->ofh == 28 && handle->desc.R == 3 ) { + result = 14; + } + } + /* Make sure it is divisible bu the ofh_rb factor in the kernel */ + while ( result % handle->fwd_ofh_rb != 0 ) { + result--; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm_blocking( libxsmm_dnn_layer* handle ) { + int result = 1; + /* For 1x1 Convolutions bring in kernel all IFMs unless filters are huge*/ + if ((handle->desc.R == 1) && (handle->desc.S == 1) ) { + result = handle->blocksifm; + if ((handle->desc.C >= 2048) && (handle->desc.K >= 512)) { + result = 1; + } + if ( (handle->target_archid < LIBXSMM_X86_AVX512) && (handle->desc.C >= 512) ) { + result = 2; + } + if ( (handle->target_archid < LIBXSMM_X86_AVX512) && (handle->desc.C >= 1024) ) { + result = 4; + } + } else { + result = 1; + /* If small image can bring in more IFMS even if NOT 1x1 convolution */ + if (handle->ofw <= 7) { + result = 2; + } + } + if (handle->blocksifm % result != 0) { + result = 1; + } + + /* In case of SPR bring always in all accumulation */ + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) { + result = handle->blocksifm; + } + + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) { + result = handle->blocksifm; + } + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_fwd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Switch to loop order 1 only if 1x1 convolution with "large" input image and "small" K */ + if ((handle->desc.H >= 28) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.C >=512) && (handle->desc.K <=512)) { + result = 1; + } + if (handle->ofw == 56 && handle->desc.R == 1 && handle->desc.C == 256 && handle->desc.K == 64 ) { + result = 1; + } + if (handle->ofw == 28 && handle->desc.R == 1) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_IFM( libxsmm_dnn_layer* handle ) { + int result = 8; + if (handle->ofw == 7 && handle->desc.C == 2048 && handle->desc.K == 512) { + result = 4; + } + /* Make sure it is divisible by ifms in the kernel */ + while (result % handle->blocksifm_blocking != 0) { + result++; + } + result = LIBXSMM_MIN(handle->blocksifm, result); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_OFM( libxsmm_dnn_layer* handle ) { + int result = 8; + if (handle->ofw == 14 && handle->desc.K == 1024) { + result = 16; + } + if (handle->ofw == 7) { + result = 16; + } + result = LIBXSMM_MIN(handle->blocksofm, result); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ofm_parallelization( libxsmm_dnn_layer* handle ) { + int result = 0; +#if 0 + /* Use "hybrid" minibatch/ofm parallelization if we have huge filters */ + if ((handle->desc.R >= 3) && (handle->desc.S >= 3) && (handle->desc.C >= 512) && (handle->desc.K >= 512)) { + result = 1; + } +#endif + if ((handle->ofw <= 7) && (handle->desc.C == 1024) && (handle->desc.K == 512)) { + result = 1; + } + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) { + if (handle->ofw == 7) { + result = 1; + } + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Avoid rim FMA if the convolution is 3x3 (non-strided) and the image is "small" */ + if ((handle->desc.R == 3) && (handle->desc.S == 3) && + (handle->desc.u == 1) && (handle->desc.v == 1) && + (handle->desc.pad_h_in == 1) && (handle->desc.pad_w_in == 1) && + (handle->desc.H == handle->desc.W) ) { + if (handle->ofw <= 28) { + result = 1; + } + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) { + result = 0; + } + } + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) { + result = 0; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_shuffle_filter_accesses( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Shuffle filter accesses only if "pure minibatch" parallelization and large filters are involved */ + if ((handle->use_ofm_parallelization == 0) && (handle->desc.C > 512) && (handle->desc.K > 512)) { + result = 1; + } + if (handle->ofw == 7 && handle->desc.R == 3 && handle->desc.C == 512) { + result = 1; + } + if (handle->ofw == 7 && handle->desc.R == 1 && handle->desc.C == 512 && handle->desc.K == 2048) { + result = 1; + } + if (handle->ofw == 7 && handle->desc.R == 1 && handle->desc.C == 2048 && handle->desc.K == 512) { + result = 1; + } + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + result = 0; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load( libxsmm_dnn_layer* handle ) { + int result = 0; + if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) { + if ((handle->desc.R == 1) && (handle->desc.S == 1)) { + if (handle->blocksifm_blocking == handle->blocksifm) { + result = 1; + } + } else { + if ((handle->blocksifm_blocking == handle->blocksifm) && (handle->avoid_fmas_in_rim == 0)) { + result = 1; + } + } + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_fwd_gemm_flags( libxsmm_dnn_layer* handle ) { + int result = 0; + +#if defined(LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS) + /* If large image and NOT already loaded in accumulators, tnen use streaming stores */ + if ((handle->ofw >= 56) && (handle->desc.K >= 256) && (handle->avoid_acc_load == 1) && (handle->desc.R == 1) && (handle->desc.S == 1)) { + result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT; + } + if (handle->ofw == 56 && handle->desc.C == 64 && handle->desc.K == 64 && handle->desc.R == 1) { + result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT; + } + if (handle->ofw == 56 && handle->desc.C == 256 && handle->desc.K == 64 && handle->desc.R == 1) { + result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT; + } + /* Disable since the GEMM output is going to f32 scratch */ + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 || handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) { + result = 0; + } +#else + LIBXSMM_UNUSED(handle); +#endif + + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) { + result = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + } + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_padding_copy( libxsmm_dnn_layer* handle ) { + int result = 0; + if ( (handle->desc.pad_h != handle->desc.pad_h_in) && (handle->desc.pad_w != handle->desc.pad_w_in) ) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_fwd_scratch( libxsmm_dnn_layer* handle ) { + handle->fwd_packing_padding_scratch_size = 0; + /* packing of input */ + if ( handle->pack_input != 0 ) { + handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C * + handle->desc.H/handle->desc.u * + handle->desc.W/handle->desc.v * + libxsmm_dnn_typesize(handle->datatype_in); + } + /* logical padding with copying in the fly */ + if ( handle->fwd_padding_copy != 0 ) { + handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C * + (handle->desc.H + 2*handle->desc.pad_h) * + (handle->desc.W + 2*handle->desc.pad_w) * + libxsmm_dnn_typesize(handle->datatype_in); + } + /* output buffer in high precision when we use BF16 */ + if ( ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) || + ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) ) { + handle->fwd_lp_output_full_scratch_size = (size_t) LIBXSMM_MAX(handle->desc.threads * handle->fwd_gemm_pixels * handle->ofmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32), handle->desc.N * handle->desc.K * handle->ofwp * handle->ofhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32)); + handle->fwd_lp_output_block_scratch_size = (size_t)handle->desc.threads * handle->fwd_ofw_rb * + handle->fwd_ofh_rb * handle->ofmblock * + libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32); + } else { + handle->fwd_lp_output_full_scratch_size = 0; + handle->fwd_lp_output_block_scratch_size = 0; + } + /* align sizes to full cacheline */ + handle->fwd_packing_padding_scratch_size += ( handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE); + handle->fwd_lp_output_full_scratch_size += ( handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE); + handle->fwd_lp_output_block_scratch_size += ( handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE); + + /* set offsets */ + handle->fwd_packing_padding_scratch_offset = 0; + handle->fwd_lp_output_full_scratch_offset = handle->fwd_packing_padding_scratch_size; + handle->fwd_lp_output_block_scratch_offset = handle->fwd_lp_output_full_scratch_offset + + handle->fwd_lp_output_full_scratch_size; + + /* set overall scratch size for forward */ + handle->fwd_scratch_size = handle->fwd_packing_padding_scratch_size + + handle->fwd_lp_output_full_scratch_size + + handle->fwd_lp_output_block_scratch_size; +} + +/**********************************************************/ +/* Helper functions for BWD convolutions' parameter setup */ +/**********************************************************/ +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_bwd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* FIXME: Fallback if MB is not divisible by number of threads */ + if (handle->desc.N % handle->desc.threads != 0) { + result = 1; + } + if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.pad_h != 0 || handle->desc.pad_w != 0)) { + result = 1; + } + if ((handle->desc.R > 1 && handle->desc.pad_h == 0) || (handle->desc.S > 1 && handle->desc.pad_w == 0)) { + result = 1; + } + if ((handle->desc.R > 1 && (handle->desc.pad_h_out == 0 || handle->desc.pad_h_in == 0)) || + (handle->desc.S > 1 && (handle->desc.pad_w_out == 0 || handle->desc.pad_w_in == 0)) ) { + result = 1; + } + if ((handle->desc.R > 1 && handle->desc.u > 1) || (handle->desc.S > 1 && handle->desc.v > 1)) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofw_rb( libxsmm_dnn_layer* handle ) { + int result = libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofh_rb( libxsmm_dnn_layer* handle ) { + int result = libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_pixels_gemm( libxsmm_dnn_layer* handle ) { + int result = handle->bwd_ofw_rb * handle->bwd_ofh_rb; + /* In the case below we calculate redundantly pixels in order to efficiently use AMX */ + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + if (handle->desc.R != 1 || handle->desc.R != 1) { + if (handle->ofw < 24) { + result = (handle->bwd_ofw_rb+2*handle->desc.pad_w) * (handle->bwd_ofh_rb-2) + 2 * (handle->bwd_ofw_rb+handle->desc.pad_w); + } + } + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_block_H( libxsmm_dnn_layer* handle ) { + int result = 0; + result = libxsmm_dnn_convolution_setup_fwd_block_H(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_bwd( libxsmm_dnn_layer* handle ) { + int result = 0; + result = libxsmm_dnn_convolution_setup_loop_order_fwd(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_IFM( libxsmm_dnn_layer* handle ) { + int result = 0; + result = LIBXSMM_MIN(handle->blocksifm, 16); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_OFM( libxsmm_dnn_layer* handle ) { + int result = 8; + while (result % handle->blocksofm_blocking != 0) { + result++; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_bwd( libxsmm_dnn_layer* handle ) { + int result = 0; + if ((handle->desc.u != 1) && (handle->bwd_ofh_rb != 1)) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ifm_parallelization( libxsmm_dnn_layer* handle ) { + int result = 0; + if (handle->ofw <= 7) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_bwd( libxsmm_dnn_layer* handle ) { + int result = libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm_blocking( libxsmm_dnn_layer* handle ) { + int result = 0; + if (handle->desc.R == 1 && handle->desc.S == 1) { + result = handle->blocksofm; + } else { + result = 1; + if (handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 7 && handle->ofw == 7) { + result = 2; + } + } + + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + result = handle->blocksofm; + } + + if (handle->blocksofm % result != 0) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_bwd_gemm_flags( libxsmm_dnn_layer* handle ) { + int result = 0; + LIBXSMM_UNUSED( handle ); + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) && ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) ) { + result = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_spread_input_bwd( libxsmm_dnn_layer* handle ) { + int result = 0; + LIBXSMM_UNUSED(handle); + if (((handle->desc.u != 1) || (handle->desc.v != 1)) && (handle->bwd_ofh_rb == 1)) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load_bwd( libxsmm_dnn_layer* handle ) { + int result = 0; + if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) { + if ((handle->desc.R == 1) && (handle->desc.S == 1)) { + if (handle->blocksofm_blocking == handle->blocksofm) { + result = 1; + } + } else { + if ((handle->blocksofm_blocking == handle->blocksofm) && (handle->avoid_fmas_in_rim == 0)) { + result = 1; + } + } + } + return result; +} + +LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bwd_scratch( libxsmm_dnn_layer* handle ) { + /* transpose of weights */ + handle->bwd_filter_trans_scratch_size = (size_t)handle->desc.C * handle->desc.K * + handle->desc.R * handle->desc.S * + libxsmm_dnn_typesize(handle->datatype_in); + + handle->bwd_packing_padding_scratch_size = 0; + /* packing of input */ + if ( handle->pack_input_bwd != 0 ) { + handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C * + handle->ofhp * handle->ofwp * + libxsmm_dnn_typesize(handle->datatype_in); + } + /* logical padding with copying in the fly */ + if ( handle->use_fallback_bwd_loops != 0 ) { + handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock * + (handle->desc.H + 2*handle->desc.pad_h) * + (handle->desc.W + 2*handle->desc.pad_w) * + libxsmm_dnn_typesize(handle->datatype_in); + } + /* input bufffer in high precision when we use BF16 */ + if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) { + handle->bwd_lp_input_full_scratch_size = (size_t) LIBXSMM_MAX(handle->desc.threads * handle->bwd_gemm_pixels * handle->ifmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32), handle->desc.N * handle->desc.C * handle->ifwp * handle->ifhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32)); + /* logical padding with copying in the fly */ + if ( handle->use_fallback_bwd_loops != 0 ) { + handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock * + (handle->desc.H + 2*handle->desc.pad_h) * + (handle->desc.W + 2*handle->desc.pad_w) * + libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32); + } + } else { + handle->bwd_lp_input_full_scratch_size = 0; + } + /* align sizes to full cacheline */ + handle->bwd_filter_trans_scratch_size += ( handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE); + handle->bwd_packing_padding_scratch_size += ( handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE); + handle->bwd_lp_input_full_scratch_size += ( handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE); + + /* set offsets */ + handle->bwd_filter_trans_scratch_offset = 0; + handle->bwd_packing_padding_scratch_offset = handle->bwd_filter_trans_scratch_size; + handle->bwd_lp_input_full_scratch_offset = handle->bwd_packing_padding_scratch_offset + + handle->bwd_packing_padding_scratch_size; + + /* set overall scratch size for forward */ + handle->bwd_scratch_size = handle->bwd_filter_trans_scratch_size + + handle->bwd_packing_padding_scratch_size + + handle->bwd_lp_input_full_scratch_size; +} + +/**********************************************************/ +/* Helper functions for UPD convolutions' parameter setup */ +/**********************************************************/ +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_upd( libxsmm_dnn_layer* handle ) { + int result = 1; + if (handle->ofh == 28 && handle->desc.R == 1 && handle->desc.u == 1 && handle->desc.C == 128 && handle->desc.K == 512) { + result = 0; + } + if (handle->ofh == 28 && handle->desc.R == 3 && handle->desc.u == 1 && handle->desc.C == 128 && handle->desc.K == 128) { + result = 0; + } + if (handle->ofw == 28 && handle->desc.R == 1 && handle->desc.C == 256 && handle->desc.K == 512) { + result = 0; + } + if (handle->ofw == 14 && !(handle->desc.R == 1 && handle->desc.C == 1024 && handle->desc.K == 256)) { + result = 0; + } + if (handle->ofw == 7) { + result = 0; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_upd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Pack input only for very small images, 1x1 convs, with large K to amortize the relevant overhead */ + if ((handle->ofh <= 7) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u != 1) && (handle->desc.v != 1) && (handle->desc.K >= 2048)) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Avoid rim FMAs only for small images */ + if ( (handle->ofh <= 7) && (handle->desc.R == 3) && (handle->desc.S == 3) && (handle->desc.pad_w == 1) && (handle->desc.pad_h == 1)) { + result = 1; + } + if (handle->desc.N != handle->desc.threads) { + result = 0; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofw_rb( libxsmm_dnn_layer* handle ) { + int result = 1; + result = handle->ofw; + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofh_rb( libxsmm_dnn_layer* handle ) { + int result = 1; + /* Restrict the reduction chain which is ofw_rb*ofh_rb*/ + if (handle->ofh <= 28 ) { + result = handle->ofh; + } + /* In the following scenario with strided convolutions and non batch reduce kernel make sure we have ofh_rb = 1 */ + if ((handle->desc.u != 1) && (handle->desc.v != 1) && (handle->upd_use_batchreduce == 0) && (handle->upd_pack_input == 0)) { + result = 1; + } + /* If using linearized taskview and have strided convs, make sure ofh_rb is 1.. */ + if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 0 && handle->upd_pack_input == 0 && handle->desc.u != 1) { + result = 1; + } + if (handle->upd_linearized_tasklist == 1 && handle->upd_use_batchreduce == 0 && (handle->desc.R != 1 || handle->desc.S != 1)) { + result = 1; + } + if (handle->upd_linearized_tasklist == 0 && handle->upd_use_batchreduce == 0 && (handle->desc.R != 1 || handle->desc.S != 1)) { + result = 1; + } + if (handle->ofw == 56 && handle->desc.R == 1) { + result = 2; + } + if (handle->upd_linearized_tasklist == 1 && handle->upd_use_batchreduce == 1 && handle->upd_avoid_rim_fmas == 1) { + result = handle->ofh; + } + + if ((handle->desc.N != handle->desc.threads) && (handle->desc.R > 1 || handle->desc.S > 1 ) && (handle->desc.u > 1 || handle->desc.v > 1 )) { + result = 1; + } + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_IFM( libxsmm_dnn_layer* handle ) { + int result = 1; + if (handle->ofh == 56 && handle->desc.R == 1 && handle->desc.S == 1 && handle->desc.u == 1 && handle->desc.v == 1) { + result = 4; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_OFM( libxsmm_dnn_layer* handle ) { + int result = 1; + LIBXSMM_UNUSED(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_img_batchreduce_block( libxsmm_dnn_layer* handle ) { + int result = 1; + LIBXSMM_UNUSED(handle); + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_batchreduce_upd( libxsmm_dnn_layer* handle ) { + int result = 1; + /* If W is large, no need for batchreduce kernel */ + if (handle->ofw >= 56) { + result = 0; + } + /* If we have packed the input, then disable batch-reduce GEMM */ + if (handle->upd_pack_input == 1) { + result = 0; + } + if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 0) { + result = 0; + } + if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 1) { + result = 1; + } + + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_weight_copies_upd( libxsmm_dnn_layer* handle ) { + int result = handle->desc.threads; + if (handle->ofw <= 14) { + result = 9; + } + if (handle->ofw == 14 && handle->desc.N == 92 && handle->desc.threads == 92) { + result = 23; + } + if (handle->ofw == 7 && handle->desc.N == 92 && handle->desc.threads == 92 && handle->desc.R == 3 && handle->desc.S == 3 && handle->desc.u == 1 && handle->desc.v == 1) { + result = 23; + } + while (handle->desc.threads % result != 0) { + result--; + } + /* FIXME: Hardcoded logic for N=27, N=26 */ + if (handle->desc.N == 27 && handle->desc.threads == 27 && handle->desc.R == 1 && handle->ofw == 14 && handle->desc.u == 1) { + result = 7; + } + if (((handle->ofh == 14) || (handle->ofw == 7 && handle->desc.u == 2)) && handle->desc.N == 26 && handle->desc.threads == 26) { + result = 13; + } + if ((handle->desc.N != handle->desc.threads) && !(handle->upd_linearized_tasklist == 0 && handle->upd_use_batchreduce == 0)) { + result = handle->desc.N; + } + /* Make sure a single copy when we use linearized-task view */ + if (handle->upd_linearized_tasklist == 1) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_linearized_tasklist_upd( libxsmm_dnn_layer* handle ) { + int result = 0; + /* Use linearized task-list (i.e. no reduction) only if small images and large filters */ + if (handle->ofh <= 10 && handle->ofw <= 10) { + result = 1; + } + if (handle->ofw == 7 && handle->desc.N == 92 && handle->desc.threads == 92 && handle->desc.R == 3 && handle->desc.S == 3 && handle->desc.u == 1 && handle->desc.v == 1) { + result = 0; + } + if (handle->ofh == 14 && handle->ofw == 14 && handle->desc.N == 23 && handle->desc.threads == 23) { + result = 1; + } +#if 0 + if ((handle->blocksofm * handle->blocksifm * handle->desc.R * handle->desc.S > (handle->desc.threads * 4)) && (handle->ofh <= 56)) { + result = 1; + } +#endif + if (handle->desc.u == 2 && handle->desc.v == 2 && handle->desc.K == 512) { + result = 0; + } + return result; +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_upd_gemm_flags( libxsmm_dnn_layer* handle ) { + int result = 0; + LIBXSMM_UNUSED(handle); + return result; +} + +LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd( libxsmm_dnn_layer* handle ) { + int remainder_pixels, max_init_offset, max_compute_offset_input, input_compute_pad, accum_length_pixels, compute_pixels; + const int multiple_target = 2; + int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp; + int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp; + int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp; + int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp; + + handle->upd_linearized_pixels = 1; + if (handle->desc.S != 1 && handle->desc.v != 1) { + handle->upd_linearized_pixels = 0; + handle->upd_trans_w_only = 0; + } + /* For large images facilitate the "large" transposes by blocking the pixel/reduction domains */ + if (handle->ofw >= 56 && handle->ofh >=56 && handle->desc.R == 1 && handle->desc.S == 1 && handle->desc.u == 1 && handle->desc.v == 1) { + handle->upd_linearized_pixels = 0; + handle->upd_trans_w_only = 1; + } + + handle->on_the_fly_input_packing = 0; + handle->upd_pack_input_upfront = 0; + handle->use_hybrid_imgofm_parallelization = 0; + handle->upd_linearized_tasklist = 0; + + if (handle->upd_linearized_pixels == 1) { + /* Logistics to pad accumulation chainlength */ + compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh-1); + remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels; + accum_length_pixels = compute_pixels + remainder_pixels; + + /* In this case compact input upfront */ + if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.u != 1 || handle->desc.v != 1)) { + handle->upd_pack_input_upfront = 1; + } + + /* Logistics for input transpose and additional pixel padding */ + max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w; + max_compute_offset_input = max_init_offset + accum_length_pixels; + input_compute_pad = (max_compute_offset_input > IFWP*IFHP) ? max_compute_offset_input - IFWP*IFHP : 0; + handle->input_pixels = IFWP * IFHP + input_compute_pad; + if (handle->upd_pack_input_upfront) { + handle->input_pixels = accum_length_pixels; + } + handle->output_pixels = accum_length_pixels; + handle->pixel_blocking = accum_length_pixels; + handle->n_used_pixels = accum_length_pixels; + handle->compute_pixels = compute_pixels; + + handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1; + + if (handle->ofw <= 14) { + handle->use_hybrid_imgofm_parallelization = 1; + handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle); + if (handle->ofw == 14 && handle->desc.K >= 1024) { + handle->use_hybrid_imgofm_parallelization = 0; + handle->weight_copies = handle->desc.threads; + } + } else { + handle->weight_copies = handle->desc.threads; + } + } + + if (handle->upd_linearized_pixels == 0) { + handle->weight_copies = handle->desc.threads; + if (handle->desc.v !=1) { + handle->on_the_fly_input_packing = 1; + } + remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw/multiple_target+1)*multiple_target - handle->ofw; + handle->ofwp_extended = OFWP + remainder_pixels; + handle->ifwp_extended = IFWP + remainder_pixels; + handle->output_pixels = OFHP * handle->ofwp_extended; + /* coverity[identical_branches] */ + handle->batchreduce_h_pixels = (handle->upd_trans_w_only) ? 1 : 1; /* TODO: identical_branches */ + handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1; + } + + if (handle->desc.N != handle->desc.threads) { + handle->use_intermediate_f32_wt_tensor = 1; + handle->use_hybrid_imgofm_parallelization = 0; + handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads); + } + +} + +LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd_amx( libxsmm_dnn_layer* handle ) { + /* JIT related variables... */ + libxsmm_blasint LDA = handle->ofmblock; + libxsmm_blasint LDB = handle->input_pixels; + libxsmm_blasint LDC = handle->ofmblock; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + size_t stride_a, stride_b; + int unroll_hint; + float beta; + + int remainder_pixels, max_init_offset, max_compute_offset_input, input_compute_pad, accum_length_pixels, compute_pixels; + const int multiple_target = 32; + int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp; + int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp; + int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp; + + handle->upd_linearized_pixels = 1; + if (handle->desc.S != 1 && handle->desc.v != 1) { + handle->upd_linearized_pixels = 0; + } + handle->fuse_upd_transposes = 1; + handle->pack_to_cnhw = 0; + handle->on_the_fly_input_packing = 0; + handle->upd_pack_input_upfront = 0; + handle->use_hybrid_imgofm_parallelization = 0; + handle->upd_linearized_tasklist = 0; + if (((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->ofw == 7) && (handle->desc.R == 1) && (handle->desc.S == 1) ) { + handle->pack_to_cnhw= 1; + } + + if (handle->upd_linearized_pixels == 1) { + if (handle->pack_to_cnhw == 0) { + handle->fuse_upd_transposes = 1; + /* Logistics to pad accumulation chainlength */ + compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh-1); + remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels; + accum_length_pixels = compute_pixels + remainder_pixels; + + /* In this case compact input upfront */ + if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.u != 1 || handle->desc.v != 1)) { + handle->upd_pack_input_upfront = 1; + } + + /* Logistics for input transpose and additional pixel padding */ + max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w; + max_compute_offset_input = max_init_offset + accum_length_pixels; + input_compute_pad = (max_compute_offset_input > IFWP*IFHP) ? max_compute_offset_input - IFWP*IFHP : 0; + handle->input_pixels = IFWP*IFHP+ input_compute_pad; + if (handle->upd_pack_input_upfront) { + handle->input_pixels = accum_length_pixels; + } + handle->output_pixels = accum_length_pixels; + handle->pixel_blocking = accum_length_pixels; + handle->n_used_pixels = accum_length_pixels; + handle->compute_pixels = compute_pixels; + + handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1; +#if 0 + handle->scratch2_size = (size_t) (handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(float)/2); + if (handle->use_intermediate_f32_wt_tensor) { + handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float); + } + handle->scratch3_size = (size_t) (handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(float)/2); +#endif + + if (handle->ofw <= 14) { + handle->use_hybrid_imgofm_parallelization = 1; + handle->fuse_upd_transposes = 0; + } else { + handle->weight_copies = handle->desc.threads; + } + + if ((handle->ofmblock % 32 != 0) || (handle->ifmblock % 32 != 0)) { + handle->fuse_upd_transposes = 0; + } + } else { + /* Logistics to pad accumulation chainlength */ + handle->use_hybrid_imgofm_parallelization = 1; + handle->weight_copies = 7; + while (handle->desc.threads % handle->weight_copies != 0) { + handle->weight_copies--; + } + compute_pixels = handle->ofw * handle->ofh * (handle->desc.N/handle->weight_copies); + remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels; + handle->remainder_pixels = remainder_pixels; + accum_length_pixels = compute_pixels + remainder_pixels; + handle->output_pixels = accum_length_pixels * handle->weight_copies; + handle->input_pixels = accum_length_pixels * handle->weight_copies; + handle->pixel_blocking = accum_length_pixels; + handle->n_used_pixels = accum_length_pixels; + + handle->use_intermediate_f32_wt_tensor = 0; +#if 0 + handle->scratch2_size = (size_t) (handle->weight_copies * handle->output_pixels * handle->desc.K * sizeof(float)/2); + handle->scratch3_size = (size_t) (handle->weight_copies * handle->input_pixels * handle->desc.C * sizeof(float)/2); +#endif + } + } + + if (handle->upd_linearized_pixels == 0) { + handle->weight_copies = handle->desc.threads; + if (handle->desc.v !=1) { + handle->on_the_fly_input_packing = 1; + } + remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw/multiple_target+1)*multiple_target - handle->ofw; + handle->remainder_pixels = remainder_pixels; + handle->ofwp_extended = OFWP + remainder_pixels; + handle->ifwp_extended = IFWP + remainder_pixels; + handle->batchreduce_h_pixels = handle->ofh; + handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1; +#if 0 + handle->scratch2_size = (size_t) (handle->desc.N * handle->ofhp*handle->ofwp_extended * handle->desc.K * sizeof(float)/2); + if (handle->use_intermediate_f32_wt_tensor) { + handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float); + } + handle->scratch3_size = (size_t) (handle->desc.N * handle->ifhp * handle->ifwp_extended * handle->desc.C * sizeof(float)/2); +#endif + } + + /* Now that all decisions have been made, JIT the proper kernel... */ + beta = (handle->use_intermediate_f32_wt_tensor) ? (float)1.0 : (float)0.0; + if (handle->upd_linearized_pixels == 0) { + LDA = handle->ofmblock; + LDB = IFHP*handle->ifwp_extended; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + unroll_hint = handle->batchreduce_h_pixels; + stride_a = handle->ofwp_extended * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in); + stride_b = handle->desc.u * handle->ifwp_extended * libxsmm_dnn_typesize(handle->datatype_in); + handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->ofw+handle->remainder_pixels, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); + handle->upd_compute_kernel_brgemm_no_linearized_pixels = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->ifmblock, handle->ofw+handle->remainder_pixels, + (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, unroll_hint, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + } else { + LDA = handle->ofmblock; + LDB = handle->input_pixels; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + if (handle->use_hybrid_imgofm_parallelization == 0) { + handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); + handle->upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + } else { + if (handle->pack_to_cnhw == 1) { + handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); + handle->upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + } else { + /* TODO: Hoist here hybrid parallelization logic and then we should be able to also provide unroll hint in the BRGEMM call */ + stride_a = handle->blocksofm * handle->output_pixels * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in); + stride_b = handle->blocksifm * handle->ifmblock * handle->input_pixels * libxsmm_dnn_typesize(handle->datatype_in); + handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); + handle->upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw = libxsmm_bsmmdispatch_reducebatch_strd(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, + (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + } + } + } + + if (handle->desc.N != handle->desc.threads) { + handle->use_intermediate_f32_wt_tensor = 1; + handle->use_hybrid_imgofm_parallelization = 0; + handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads); + } + +} + +LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_padding_copy( libxsmm_dnn_layer* handle ) { + int result = 0; + if ( (handle->desc.pad_h != handle->desc.pad_h_in) && (handle->desc.pad_w != handle->desc.pad_w_in) ) { + result = 1; + } + return result; +} + +LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_upd_scratch( libxsmm_dnn_layer* handle ) { + handle->upd_packing_padding_scratch_size = 0; + /* packing of input */ + if ( handle->upd_pack_input != 0 ) { + handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C * + handle->desc.H/handle->desc.u * + handle->desc.W/handle->desc.v * + libxsmm_dnn_typesize(handle->datatype_in); + } + /* logical padding with copying in the fly */ + if ( handle->upd_padding_copy != 0 ) { + handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C * + (handle->desc.H + 2*handle->desc.pad_h) * + (handle->desc.W + 2*handle->desc.pad_w) * + libxsmm_dnn_typesize(handle->datatype_in); + } + /* output/input buffer to transpose when we use bf16 */ + if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) { + if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp; + int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp; + + if (handle->upd_linearized_pixels == 1) { + handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(handle->datatype_in)); + handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(handle->datatype_in)); + } + + if (handle->upd_linearized_pixels == 0) { + handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * OFHP * handle->ofwp_extended * handle->desc.K * sizeof(handle->datatype_in)); + handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * IFHP * handle->ifwp_extended * handle->desc.C * sizeof(handle->datatype_in)); + } + } else { + const int multiple_target = 2; + int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp; + int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp; + int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp; + int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp; + + if (handle->upd_linearized_pixels == 1) { + int compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh-1); + int remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels/multiple_target+1)*multiple_target - compute_pixels; + int accum_length_pixels = compute_pixels + remainder_pixels; + + int max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w; + int max_compute_offset_input = max_init_offset + accum_length_pixels; + int input_compute_pad = (max_compute_offset_input > IFWP*IFHP) ? max_compute_offset_input - IFWP*IFHP : 0; + int input_pixels = IFWP * IFHP + input_compute_pad; + + if (handle->upd_pack_input_upfront == 1) { + input_pixels = accum_length_pixels; + } + + handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * accum_length_pixels * handle->desc.K * sizeof(handle->datatype_in)); + handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * input_pixels * handle->desc.C * sizeof(handle->datatype_in)); + } + + if (handle->upd_linearized_pixels == 0) { + int remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw/multiple_target+1)*multiple_target - handle->ofw; + int ofwp_extended = OFWP + remainder_pixels; + int ifwp_extended = IFWP + remainder_pixels; + + handle->upd_lp_output_full_scratch_size = (size_t) (handle->desc.N * OFHP * ofwp_extended * handle->desc.K * sizeof(handle->datatype_in)); + handle->upd_lp_input_full_scratch_size = (size_t) (handle->desc.N * IFHP * ifwp_extended * handle->desc.C * sizeof(handle->datatype_in)); + } + } + handle->upd_lp_filter_full_scratch_size = (size_t)handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * + libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32); + } else { + handle->upd_lp_output_full_scratch_size = 0; + handle->upd_lp_input_full_scratch_size = 0; + handle->upd_lp_filter_full_scratch_size = 0; + } + /* filter scratch */ + handle->upd_filter_scratch_size = (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * LIBXSMM_MAX(handle->desc.threads, handle->desc.N) * sizeof(float); + + /* align sizes to full cacheline */ + handle->upd_packing_padding_scratch_size += ( handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE); + handle->upd_lp_output_full_scratch_size += ( handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE); + handle->upd_lp_input_full_scratch_size += ( handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE); + handle->upd_filter_scratch_size += ( handle->upd_filter_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->upd_filter_scratch_size % LIBXSMM_CACHELINE); + handle->upd_lp_filter_full_scratch_size += ( handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 : + LIBXSMM_CACHELINE - (handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE); + + /* calculate offsets */ + handle->upd_packing_padding_scratch_offset = 0; + handle->upd_lp_output_full_scratch_offset = handle->upd_packing_padding_scratch_size; + handle->upd_lp_input_full_scratch_offset = handle->upd_lp_output_full_scratch_offset + handle->upd_lp_output_full_scratch_size; + handle->upd_filter_scratch_offset = handle->upd_lp_input_full_scratch_offset + handle->upd_lp_input_full_scratch_size; + handle->upd_lp_filter_full_scratch_offset = handle->upd_filter_scratch_offset + handle->upd_filter_scratch_size; + + /* set overall scratch size for update */ + handle->upd_scratch_size = handle->upd_packing_padding_scratch_size + + handle->upd_lp_output_full_scratch_size + + handle->upd_lp_input_full_scratch_size + + handle->upd_filter_scratch_size + + handle->upd_lp_filter_full_scratch_size; +} + +LIBXSMM_API_INLINE libxsmm_dnn_err_t libxsmm_dnn_convolution_setup( libxsmm_dnn_layer* handle ) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + libxsmm_blasint _ldi = 64, _ldo = 64; + libxsmm_blasint ldx; + libxsmm_blasint ldA; + libxsmm_blasint ldC; + int beta_int; + float beta; + int l_flags; + int l_tc_flags; + + /* init libxsmm */ + LIBXSMM_INIT + + /* Generic parameter setup */ + handle->target_archid = libxsmm_target_archid; + if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && ((handle->desc.C % 16 != 0) || (handle->desc.K % 16 != 0)) ) { + handle->target_archid = LIBXSMM_X86_AVX512_CPX; + } + handle->ifmblock = libxsmm_dnn_convolution_setup_ifmblock(handle); + handle->ofmblock = libxsmm_dnn_convolution_setup_ofmblock(handle); + handle->fm_lp_block = libxsmm_dnn_convolution_setup_fm_lp_block(handle); + handle->blocksifm = libxsmm_dnn_convolution_setup_blocksifm(handle); + handle->blocksofm = libxsmm_dnn_convolution_setup_blocksofm(handle); + + /* If in SPR, generate tilerelease kernel */ + if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->ifmblock, handle->ifmblock, handle->ifmblock, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL); + } + + /* FWD parameter setup */ + handle->fwd_ofw_rb = libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle); + handle->pack_input = libxsmm_dnn_convolution_setup_pack_input_fwd(handle); + handle->fwd_ofh_rb = libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle); + handle->fwd_gemm_pixels = libxsmm_dnn_convolution_setup_fwd_pixels_gemm(handle); + handle->block_fwd_oj = libxsmm_dnn_convolution_setup_fwd_block_H(handle); + handle->loop_order = libxsmm_dnn_convolution_setup_loop_order_fwd(handle); + handle->blocksifm_blocking = libxsmm_dnn_convolution_setup_blocksifm_blocking(handle); + handle->block_fwd_ofm = libxsmm_dnn_convolution_setup_block_fwd_OFM(handle); + handle->block_fwd_ifm = libxsmm_dnn_convolution_setup_block_fwd_IFM(handle); + handle->avoid_fmas_in_rim = libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle); + handle->use_ofm_parallelization = libxsmm_dnn_convolution_setup_use_ofm_parallelization(handle); + handle->shuffle_filter_accesses = libxsmm_dnn_convolution_setup_shuffle_filter_accesses(handle); + handle->avoid_acc_load = libxsmm_dnn_convolution_setup_avoid_acc_load(handle); + handle->fwd_flags = libxsmm_dnn_convolution_setup_init_fwd_gemm_flags(handle); + handle->use_fallback_fwd_loops = libxsmm_dnn_convolution_setup_fallback_loops_fwd(handle); + handle->fwd_padding_copy = libxsmm_dnn_convolution_setup_fwd_padding_copy(handle); + +#if 0 + if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) { + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + handle->block_fwd_ofm = 1; + handle->block_fwd_oj = handle->fwd_ofh_rb; + ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + ldA = handle->ofmblock; + ldC = handle->ofmblock; + beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0; + l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + handle->fwd_compute_kernel_offs_f32 = NULL; + handle->fwd_compute_kernel_strd_f32 = NULL; + handle->fwd_compute_kernel_addr_a_f32 = NULL; + handle->fwd_compute_kernel_addr_b_f32 = NULL; + if (handle->desc.R == 1 && handle->desc.S == 1) { + const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp; + const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp; + int stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in); + int stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in); + handle->fwd_compute_kernel_strd_f32 = libxsmm_smmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, stride_a, stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + } else { + const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp ); + const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp ); + int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking; + int i = 0, ifm, ki, kj; + handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock + + kj * handle->desc.S * handle->ifmblock * handle->ofmblock + + ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in); + handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock + + kj * IFW * handle->ifmblock + + ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in); + i++; + } + } + } + handle->fwd_compute_kernel_offs_f32 = libxsmm_smmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + } + handle->fwd_compute_kernel_addr_a_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + handle->fwd_compute_kernel_addr_b_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + } +#endif + + if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ) { + handle->block_fwd_ofm = 1; + handle->block_fwd_oj = handle->fwd_ofh_rb; + ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + ldA = handle->ofmblock; + ldC = handle->ofmblock; + beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0; + l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + handle->fwd_compute_kernel_addr = NULL; + handle->fwd_compute_kernel_offs_a = NULL; + handle->fwd_compute_kernel_offs_b = NULL; + handle->fwd_compute_kernel_strd = NULL; + if (handle->desc.R == 1 && handle->desc.S == 1) { + const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp; + const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp; + size_t stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in); + size_t stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in); + handle->fwd_compute_kernel_strd = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, + (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + } else { + const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp ); + const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp ); + int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking; + int i = 0, ifm, ki, kj; + handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock + + kj * handle->desc.S * handle->ifmblock * handle->ofmblock + + ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in); + handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock + + kj * IFW * handle->ifmblock + + ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in); + i++; + } + } + } + handle->fwd_compute_kernel_offs_a = libxsmm_bmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + handle->fwd_compute_kernel_offs_b = libxsmm_bsmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + } + handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_tc_flags, NULL); + } + + handle->code_fwd[0].ptr = 0; + handle->code_fwd[1].ptr = 0; + handle->code_fwd[2].ptr = 0; + + /* JIT cvt eltwise functions for fwd convolutions */ + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) { + _ldi = handle->ofmblock * handle->ofwp; + _ldo = handle->ofmblock * handle->ofwp; + handle->fwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->ofmblock * handle->fwd_ofw_rb, handle->fwd_ofh_rb, &_ldi, &_ldo, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); + } + + /* Create strided BRGEMMs for i8i32 convolutions */ + if ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I32)) { + ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + ldA = handle->ofmblock; + ldC = handle->ofmblock; + beta_int = (handle->avoid_acc_load) ? 0 : 1; + l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags; + if (handle->desc.R == 1 && handle->desc.S == 1) { + const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp; + const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp; + libxsmm_blasint stride_A = handle->ifmblock * handle->ofmblock * sizeof(char); + libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char) ; + handle->gemm_fwd.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL); + } else { + const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp; + const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp; + if (handle->avoid_fmas_in_rim == 0) { + int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking; + int i = 0, ifm, ki, kj; + handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock + + kj * handle->desc.S * handle->ifmblock * handle->ofmblock + + ki * handle->ifmblock * handle->ofmblock) * sizeof(char); + handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock + + kj * IFW * handle->ifmblock + + ki * handle->ifmblock) * sizeof(char); + i++; + } + } + } + handle->gemm_fwd.xgemm.subimro = libxsmm_subimmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL); + } else { + libxsmm_blasint stride_A = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(char); + libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char) ; + handle->gemm_fwd.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL); + handle->gemm_fwd2.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL); + } + } + } else if ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I8)) { + ldx = (libxsmm_blasint)handle->desc.v*handle->ifmblock; + ldA = handle->ofmblock; + ldC = handle->ofmblock; + beta_int = 0; + l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags; + if (handle->desc.R == 1 && handle->desc.S == 1) { + const int IFW = handle->ifwp; + const int IFH = handle->ifhp; + libxsmm_blasint stride_A = handle->ifmblock * handle->ofmblock * sizeof(char); + libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char) ; + handle->gemm_fwd.xgemm.sububmrs = libxsmm_sububmmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL); + } else { + const int IFW = handle->ifwp; + const int IFH = handle->ifhp; + int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking; + int i = 0, ifm, ki, kj; + handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock + + kj * handle->desc.S * handle->ifmblock * handle->ofmblock + + ki * handle->ifmblock * handle->ofmblock) * sizeof(char); + handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock + + kj * IFW * handle->ifmblock + + ki * handle->ifmblock) * sizeof(char); + i++; + } + } + } + handle->gemm_fwd.xgemm.sububmro = libxsmm_sububmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL); + } + } + +#if 0 + /* Spit out FWD parameters that are selected... */ + printf("FWD params...\n"); + printf("Fwd_ofw_rb = %d\n", handle->fwd_ofw_rb); + printf("Fwd_ofh_rb = %d\n", handle->fwd_ofh_rb); + printf("Pack input = %d\n", handle->pack_input); + printf("Block oj = %d\n", handle->block_fwd_oj); + printf("Loop order = %d\n", handle->loop_order); + printf("Blocksifm_blocking = %d\n", handle->blocksifm_blocking); + printf("Block fwd ofm = %d\n", handle->block_fwd_ofm); + printf("Block fwd ifm = %d\n", handle->block_fwd_ifm); + printf("Avoid rim fmas = %d\n", handle->avoid_fmas_in_rim); + printf("Ofm parallelization = %d\n", handle->use_ofm_parallelization); + printf("Shuffle filter accesses = %d\n", handle->shuffle_filter_accesses); + printf("Avoid acc load = %d\n", handle->avoid_acc_load); + printf("Fwd GEMM flags = %d\n", handle->fwd_flags); +#endif + + /* BWD parameter setup */ + handle->bwd_ofw_rb = libxsmm_dnn_convolution_setup_bwd_ofw_rb(handle); + handle->bwd_ofh_rb = libxsmm_dnn_convolution_setup_bwd_ofh_rb(handle); + handle->bwd_gemm_pixels = libxsmm_dnn_convolution_setup_bwd_pixels_gemm(handle); + handle->pack_input_bwd = libxsmm_dnn_convolution_setup_pack_input_bwd(handle); + handle->spread_input_bwd = libxsmm_dnn_convolution_setup_spread_input_bwd(handle); + handle->blocksofm_blocking = libxsmm_dnn_convolution_setup_blocksofm_blocking(handle); + handle->avoid_acc_load_bwd = libxsmm_dnn_convolution_setup_avoid_acc_load_bwd(handle); + handle->use_ifm_parallelization = libxsmm_dnn_convolution_setup_use_ifm_parallelization(handle); + handle->block_bwd_ofm = libxsmm_dnn_convolution_setup_block_bwd_OFM(handle); + handle->block_bwd_ifm = libxsmm_dnn_convolution_setup_block_bwd_IFM(handle); + handle->block_bwd_oj = libxsmm_dnn_convolution_setup_bwd_block_H(handle); + handle->use_fallback_bwd_loops = libxsmm_dnn_convolution_setup_fallback_loops_bwd(handle); + handle->bwd_flags = libxsmm_dnn_convolution_setup_init_bwd_gemm_flags(handle); + + if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ) { + handle->block_bwd_ifm = 1; + handle->block_bwd_oj = handle->bwd_ofh_rb ; + ldx = ((libxsmm_blasint)handle->ofmblock); + ldA = handle->ifmblock; + ldC = (handle->spread_input_bwd == 1) ? handle->ifmblock * handle->desc.v : handle->ifmblock; + beta = (handle->avoid_acc_load_bwd) ? (float)0.0 : (float)1.0; + l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + handle->bwd_compute_kernel_addr = NULL; + handle->bwd_compute_kernel_offs = NULL; + handle->bwd_compute_kernel_strd = NULL; + if (handle->desc.R == 1 && handle->desc.S == 1) { + size_t stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in); + size_t stride_b = handle->ofwp * handle->ofhp * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in); + handle->bwd_compute_kernel_strd = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, + (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, handle->blocksofm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + } else { + int n_blocks = handle->desc.R * handle->desc.S * handle->blocksofm_blocking; + int i = 0, ofm, ki, kj; + handle->A_offsets_bwd = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + handle->B_offsets_bwd = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long)); + for (ofm = 0; ofm < handle->blocksofm_blocking; ofm++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + handle->A_offsets_bwd[i] = (ofm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock + + kj * handle->desc.S * handle->ifmblock * handle->ofmblock + + ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in); + handle->B_offsets_bwd[i] = (ofm * handle->ofhp * handle->ofwp * handle->ofmblock + + kj * handle->ofwp * handle->ofmblock + + ki * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in); + i++; + } + } + } + handle->bwd_compute_kernel_offs = libxsmm_bsmmdispatch_reducebatch_offs(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + } + handle->bwd_config_kernel = libxsmm_bsmmdispatch(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_tc_flags, NULL); + } + +#if 0 + /* Spit out BWD parameters that are selected... */ + printf("BWD params...\n"); + printf("Bwd_ofw_rb = %d\n", handle->bwd_ofw_rb); + printf("Bwd_ofh_rb = %d\n", handle->bwd_ofh_rb); + printf("Pack input = %d\n", handle->pack_input_bwd); + printf("Spread input = %d\n", handle->spread_input_bwd); + printf("Blocksofm_blocking = %d\n", handle->blocksofm_blocking); + printf("Avoid acc load = %d\n", handle->avoid_acc_load_bwd); + printf("Ifm parallelization = %d\n", handle->use_ifm_parallelization); + printf("Block bwd ofm = %d\n", handle->block_bwd_ofm); + printf("Block bwd ifm = %d\n", handle->block_bwd_ifm); + printf("Block oj = %d\n", handle->block_bwd_oj); +#endif + + handle->code_bwd[0].ptr = 0; + handle->code_bwd[1].ptr = 0; + handle->code_bwd[2].ptr = 0; + + /* Transpose kernel used for filter transpose in bwd pass */ + handle->tr_kernel = libxsmm_dispatch_meltw_unary(64, 16, &(_ldi), &(_ldo), LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + + /* UPD parameter setup */ + handle->upd_linearized_tasklist = libxsmm_dnn_convolution_setup_linearized_tasklist_upd(handle); + handle->upd_avoid_rim_fmas = libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd(handle); + handle->upd_pack_input = libxsmm_dnn_convolution_setup_pack_input_upd(handle); + handle->upd_use_batchreduce = libxsmm_dnn_convolution_setup_use_batchreduce_upd(handle); + handle->upd_ofw_rb = libxsmm_dnn_convolution_setup_upd_ofw_rb(handle); + handle->upd_ofh_rb = libxsmm_dnn_convolution_setup_upd_ofh_rb(handle); + handle->upd_loop_order = libxsmm_dnn_convolution_setup_loop_order_upd(handle); + handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle); + handle->block_upd_ofm = libxsmm_dnn_convolution_setup_block_upd_OFM(handle); + handle->block_upd_ifm = libxsmm_dnn_convolution_setup_block_upd_IFM(handle); + handle->upd_loop_order = libxsmm_dnn_convolution_setup_loop_order_upd(handle); + handle->upd_padding_copy = libxsmm_dnn_convolution_setup_upd_padding_copy(handle); + + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) { + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) { + libxsmm_dnn_convolution_setup_bf16_upd_amx(handle); + } else { + libxsmm_dnn_convolution_setup_bf16_upd(handle); + } + } + +#if 0 + /* Spit out UPD parameters that are selected... */ + printf("UPD params...\n"); + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) { + printf("BF16 path...\n"); + printf("UPD use_hybrid_imgofm_parallelization = %d\n", handle->use_hybrid_imgofm_parallelization); + printf("UPD linearized_pixels = %d\n", handle->upd_linearized_pixels); + printf("UPD upd_trans_w_only = %d\n", handle->upd_trans_w_only); + printf("UPD on_the_fly_input_packing = %d\n", handle->on_the_fly_input_packing); + printf("UPD use_intermediate_f32_wt_tensor = %d\n", handle->use_intermediate_f32_wt_tensor); + printf("UPD pack to CNHW format = %d\n", handle->pack_to_cnhw); + printf("UPD batchreduce H pixels = %d\n", handle->batchreduce_h_pixels); + } + printf("UPD linearized tasks = %d\n", handle->upd_linearized_tasklist); + printf("UPD avoid rim fmas = %d\n", handle->upd_avoid_rim_fmas); + printf("UPD Pack input = %d\n", handle->upd_pack_input); + printf("UPD use batch-reduce GEMM = %d\n", handle->upd_use_batchreduce); + printf("Upd_ofw_rb = %d\n", handle->upd_ofw_rb); + printf("Upd_ofh_rb = %d\n", handle->upd_ofh_rb); + printf("UPD loop order = %d\n", handle->upd_loop_order); + printf("UPD weight_copies = %d\n", handle->weight_copies); + printf("Block upd ofm = %d\n", handle->block_upd_ofm); + printf("Block upd ifm = %d\n", handle->block_upd_ifm); +#endif + + handle->code_upd[0].ptr = 0; + handle->code_upd[1].ptr = 0; + + /* prepare barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + + /* setup up scratch */ + libxsmm_dnn_convolution_setup_fwd_scratch( handle ); + libxsmm_dnn_convolution_setup_bwd_scratch( handle ); + libxsmm_dnn_convolution_setup_upd_scratch( handle ); + handle->scratch = 0; + handle->scratch_size = LIBXSMM_MAX( handle->fwd_scratch_size, LIBXSMM_MAX( handle->bwd_scratch_size, handle->upd_scratch_size ) ); + + return status; +} + +#undef MIXED +#undef KHWC +#undef HWKC +#undef CHWK +#undef HWCK + + +LIBXSMM_API libxsmm_dnn_layer* libxsmm_dnn_create_conv_layer( + libxsmm_dnn_conv_desc conv_desc, + libxsmm_dnn_err_t* status) +{ + libxsmm_dnn_layer* handle = 0; + *status = LIBXSMM_DNN_SUCCESS; + + /* currently we don't support NCHW */ + if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCHW) > 0 ) { + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW; + return 0; + } + /* currently we don't support KCRS */ + if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_KCRS) > 0 ) { + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS; + return 0; + } + /* we only support physical paddind in these days */ + /* @TODO: add logical padding support for other datatypes than FP32 */ + if ( ( ( conv_desc.pad_h != conv_desc.pad_h_in ) || + ( conv_desc.pad_w != conv_desc.pad_w_in ) || + ( conv_desc.pad_h != conv_desc.pad_h_out ) || + ( conv_desc.pad_w != conv_desc.pad_w_out ) ) && + ( conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32 ) && (conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) ) { + *status = LIBXSMM_DNN_ERR_INVALID_PADDING; + return 0; + } + + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_layer*)calloc(1, sizeof(libxsmm_dnn_layer)); + + if (0 != handle) { + /* initialize known handle components */ + handle->desc = conv_desc; + handle->datatype_in = conv_desc.datatype_in; + handle->datatype_out = conv_desc.datatype_out; + /* select the intermediate format, only applicable for integer types */ + if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) { + /* error */ + } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_BF16) ) { + /* error */ + } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) { + /* error */ + } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_I32) ) { + /* error */ + } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_I8) ) { + /* error */ + } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) { + /* error */ + } else { + /* fine, no error */ + } + handle->buffer_format = conv_desc.buffer_format; + handle->filter_format = conv_desc.filter_format; + handle->fuse_ops = conv_desc.fuse_ops; + handle->options = conv_desc.options; + + /* derive additional values */ + handle->ifhp = conv_desc.H + 2*conv_desc.pad_h_in; + handle->ifwp = conv_desc.W + 2*conv_desc.pad_w_in; + handle->ofh = (conv_desc.H + 2*conv_desc.pad_h - conv_desc.R) / conv_desc.u + 1; + handle->ofw = (conv_desc.W + 2*conv_desc.pad_w - conv_desc.S) / conv_desc.v + 1; + handle->ofhp = handle->ofh + 2*conv_desc.pad_h_out; + handle->ofwp = handle->ofw + 2*conv_desc.pad_w_out; + handle->ifmblock = 1; + handle->ofmblock = 1; + handle->blocksifm = conv_desc.C; + handle->blocksofm = conv_desc.K; + handle->fwd_ofw_rb = 1; + handle->fwd_ofh_rb = 1; + handle->bwd_ofw_rb = 1; + handle->bwd_ofh_rb = 1; + handle->upd_ofw_rb = 1; + handle->upd_ofh_rb = 1; + handle->fm_lp_block = 1; + handle->blocksifm_blocking = 1; + handle->blocksofm_blocking = 1; + /* Set algorithm to use */ + if (conv_desc.algo == LIBXSMM_DNN_CONV_ALGO_AUTO) { + handle->algo = LIBXSMM_DNN_CONV_ALGO_DIRECT; + } else { + handle->algo = conv_desc.algo; + } + if ( handle->algo != LIBXSMM_DNN_CONV_ALGO_DIRECT ) { + *status = LIBXSMM_DNN_ERR_INVALID_ALGO; + free(handle); + handle = 0; + return 0; + } + + *status = libxsmm_dnn_convolution_setup(handle); + } + else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + /* account for eventually deallocated handle */ + if ( LIBXSMM_DNN_SUCCESS != *status ) { + handle = 0; + } + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_conv_layer(const libxsmm_dnn_layer* handle) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + + /* deallocate handle structure itself */ + free(/*remove constness*/(libxsmm_dnn_layer*)handle); + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_create_tensor_datalayout(const libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->format = handle->buffer_format; + layout->tensor_type = LIBXSMM_DNN_ACTIVATION; + + if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->ifwp; + layout->dim_size[2] = handle->ifhp; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofwp; + layout->dim_size[2] = handle->ofhp; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + /* @TODO this need to change */ + } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I32) ) { + if ( ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) ) ) { + layout->datatype = handle->datatype_in; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->datatype = handle->datatype_out; + } + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->ifwp; + layout->dim_size[2] = handle->ifhp; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofwp; + layout->dim_size[2] = handle->ofhp; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->ifwp; + layout->dim_size[2] = handle->ifhp; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofwp; + layout->dim_size[2] = handle->ofhp; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) { + if ( ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) ) ) { + layout->datatype = handle->datatype_in; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) { + layout->datatype = handle->datatype_out; + } + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->ifwp; + layout->dim_size[2] = handle->ifhp; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofwp; + layout->dim_size[2] = handle->ofhp; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofwp; + layout->dim_size[2] = handle->ofhp; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->ifwp; + layout->dim_size[2] = handle->ifhp; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock * handle->blocksifm; + layout->dim_size[1] = handle->ifwp; + layout->dim_size[2] = handle->ifhp; + layout->dim_size[3] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock * handle->blocksofm; + layout->dim_size[1] = handle->ofwp; + layout->dim_size[2] = handle->ofhp; + layout->dim_size[3] = handle->desc.N; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_FILTER) ) { + layout->format = handle->filter_format; + layout->tensor_type = LIBXSMM_DNN_FILTER; + + if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 6; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ifmblock; + layout->dim_size[2] = handle->desc.S; + layout->dim_size[3] = handle->desc.R; + layout->dim_size[4] = handle->blocksifm; + layout->dim_size[5] = handle->blocksofm; + } + } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 7; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->fm_lp_block; + layout->dim_size[1] = handle->ofmblock; + layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block; + layout->dim_size[3] = handle->desc.S; + layout->dim_size[4] = handle->desc.R; + layout->dim_size[5] = handle->blocksifm; + layout->dim_size[6] = handle->blocksofm; + } + } else if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) ) { + if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_FILTER) ) { + layout->datatype = handle->datatype_in; + } else if (type == LIBXSMM_DNN_GRADIENT_FILTER) { + layout->datatype = handle->datatype_out; + } + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + if ((type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_FILTER)) { + layout->num_dims = 7; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->fm_lp_block; + layout->dim_size[1] = handle->ofmblock; + layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block; + layout->dim_size[3] = handle->desc.S; + layout->dim_size[4] = handle->desc.R; + layout->dim_size[5] = handle->blocksifm; + layout->dim_size[6] = handle->blocksofm; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) { + if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_size[0] = handle->ofmblock * handle->blocksofm; + layout->dim_size[1] = handle->ifmblock * handle->blocksifm; + layout->dim_size[2] = handle->desc.S; + layout->dim_size[3] = handle->desc.R; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) { + layout->format = handle->filter_format; + layout->tensor_type = LIBXSMM_DNN_REGULAR_FILTER_TRANS; + + if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 6; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->ofmblock; + layout->dim_size[2] = handle->desc.S; + layout->dim_size[3] = handle->desc.R; + layout->dim_size[4] = handle->blocksofm; + layout->dim_size[5] = handle->blocksifm; + } + } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 7; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->fm_lp_block; + layout->dim_size[1] = handle->ifmblock; + layout->dim_size[2] = handle->ofmblock/handle->fm_lp_block; + layout->dim_size[3] = handle->desc.S; + layout->dim_size[4] = handle->desc.R; + layout->dim_size[5] = handle->blocksofm; + layout->dim_size[6] = handle->blocksifm; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } +#if 0 + } else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) { + if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_size[0] = handle->ofmblock * handle->blocksofm; + layout->dim_size[1] = handle->ifmblock * handle->blocksifm; + layout->dim_size[2] = handle->desc.S; + layout->dim_size[3] = handle->desc.K; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } +#endif + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) { + layout->format = handle->buffer_format; + layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR; + + if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->datatype_out; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->blocksofm; + } +#if 0 + } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 3; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->fm_lp_block; + layout->dim_size[1] = handle->ofmblock; + layout->dim_size[2] = handle->blocksofm; + } +#endif + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + layout->datatype = handle->datatype_out; + if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 1; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->ofmblock*handle->blocksofm; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_BATCH_STATS) ) { + layout->format = handle->buffer_format; + layout->tensor_type = LIBXSMM_DNN_BATCH_STATS; + + if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->desc.N; + layout->dim_size[2] = handle->blocksofm; + layout->dim_size[3] = 2; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if (type == LIBXSMM_DNN_MAX_STATS_FWD) { + layout->format = handle->buffer_format; + layout->tensor_type = LIBXSMM_DNN_MAX_STATS_FWD; + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.N; + } + } else if (type == LIBXSMM_DNN_MAX_STATS_BWD) { + layout->format = handle->buffer_format; + layout->tensor_type = LIBXSMM_DNN_MAX_STATS_BWD; + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.N; + } + } else if (type == LIBXSMM_DNN_MAX_STATS_UPD) { + layout->format = handle->buffer_format; + layout->tensor_type = LIBXSMM_DNN_MAX_STATS_UPD; + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.N; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_bf16_filter(const libxsmm_dnn_layer* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (handle != 0) { + if ( (handle->reg_filter != 0) && (handle->reg_filter_tr != 0) ) { + /* TODO handle more datatypes */ + int ifm1, ifm2, kj, ki, ofm1, ofm2; + int ofmblock_lp = handle->ofmblock/handle->fm_lp_block; + int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; + int lpb = handle->fm_lp_block; + LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, wt, (libxsmm_bfloat16*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); + LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, tr_wt, (libxsmm_bfloat16*)handle->reg_filter_tr->data, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + + /* TODO we might want to do this in parallel.... */ + for ( ifm1 = 0; ifm1 < handle->blocksifm; ++ifm1 ) { + for ( ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1 ) { + for (kj=0; kj < handle->desc.R; ++kj) { + for (ki=0; ki < handle->desc.S; ++ki) { + for ( ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2 ) { + for ( ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2 ) { + LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2/lpb, ifm2, ofm2%lpb, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb) = + LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, ifm2/lpb, ofm2, ifm2%lpb, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); + } + } + } + } + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_filter(const libxsmm_dnn_layer* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (handle != 0) { + if ( (handle->reg_filter != 0) && (handle->reg_filter_tr != 0) ) { + /* TODO handle more datatypes */ + int ifm1, ifm2, kj, ki, ofm1, ofm2; + LIBXSMM_VLA_DECL(6, float, wt, (float*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + LIBXSMM_VLA_DECL(6, float, tr_wt, (float*)handle->reg_filter_tr->data, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + + /* TODO we might want to do this in parallel.... */ + for ( ifm1 = 0; ifm1 < handle->blocksifm; ++ifm1 ) { + for ( ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1 ) { + for (kj=0; kj < handle->desc.R; ++kj) { + for (ki=0; ki < handle->desc.S; ++ki) { + for ( ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2 ) { + for ( ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2 ) { + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } + } + } + } + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) && + (type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) && (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + handle->reg_filter = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + handle->grad_filter = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) { + handle->reg_bias = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) { + handle->grad_bias = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) { + handle->reg_filter_tr = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_BATCH_STATS ) { + handle->batch_stats = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) { + handle->maxstats_fwd = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) { + handle->maxstats_bwd = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) { + handle->maxstats_upd = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_get_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) +{ + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) && + (type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) && (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + return_tensor = handle->reg_input; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + return_tensor = handle->grad_input; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + return_tensor = handle->reg_output; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + return_tensor = handle->grad_output; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + return_tensor = handle->reg_filter; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + return_tensor = handle->grad_filter; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) { + return_tensor = handle->reg_bias; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) { + return_tensor = handle->grad_bias; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) { + return_tensor = handle->reg_filter_tr; + } else if ( type == LIBXSMM_DNN_BATCH_STATS ) { + return_tensor = handle->batch_stats; + } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) { + return_tensor = handle->maxstats_fwd; + } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) { + return_tensor = handle->maxstats_bwd; + } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) { + return_tensor = handle->maxstats_upd; + } else { + /* cannot happen */ + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) && + (type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) && (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + handle->reg_filter = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + handle->grad_filter = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) { + handle->reg_bias = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) { + handle->grad_bias = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) { + handle->reg_filter_tr = 0; + } else if ( type == LIBXSMM_DNN_BATCH_STATS ) { + handle->batch_stats = 0; + } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) { + handle->maxstats_fwd = 0; + } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) { + handle->maxstats_bwd = 0; + } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) { + handle->maxstats_upd = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API size_t libxsmm_dnn_get_scratch_size(const libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status) +{ + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: break; + case LIBXSMM_DNN_COMPUTE_KIND_UPD: break; + case LIBXSMM_DNN_COMPUTE_KIND_ALL: break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + l_scratch_size += handle->scratch_size + 64; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, const void* scratch) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + address += handle->scratch_size + 64; + + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: break; + case LIBXSMM_DNN_COMPUTE_KIND_UPD: break; + case LIBXSMM_DNN_COMPUTE_KIND_ALL: break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: break; + case LIBXSMM_DNN_COMPUTE_KIND_UPD: break; + case LIBXSMM_DNN_COMPUTE_KIND_ALL: break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API_INLINE libxsmm_dnn_err_t internal_execute_st(libxsmm_dnn_layer* handle, + libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (handle->algo) { + case LIBXSMM_DNN_CONV_ALGO_DIRECT: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + switch (handle->buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_fwd_custom_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: { + status = libxsmm_dnn_convolve_st_fwd_nhwc_rsck(handle, start_thread, tid); + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_fwd_nhwc_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + switch (handle->buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: { + status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck(handle, start_thread, tid); + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_bwd_nhwc_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_UPD: { + switch (handle->buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_upd_custom_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: { + status = libxsmm_dnn_convolve_st_upd_nhwc_rsck(handle, start_thread, tid); + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_upd_nhwc_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: { + switch (handle->buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_upd_custom_custom(handle, start_thread, tid); + status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: { + switch (handle->filter_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: { + status = libxsmm_dnn_convolve_st_upd_nhwc_rsck(handle, start_thread, tid); + status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck(handle, start_thread, tid); + } break; + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_convolve_st_upd_nhwc_custom(handle, start_thread, tid); + status = libxsmm_dnn_convolve_st_bwd_nhwc_custom(handle, start_thread, tid); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_ALGO; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_execute_st(libxsmm_dnn_layer* handle, + libxsmm_dnn_compute_kind kind, /*unsigned*/int start_thread, /*unsigned*/int tid) +{ + return internal_execute_st(handle, kind, start_thread, tid); +} + + +LIBXSMM_API void libxsmm_dnn_execute(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind) +{ +#if defined(_OPENMP) +# pragma omp parallel num_threads(handle->desc.threads) + { + const int tid = omp_get_thread_num(); + internal_execute_st(handle, kind, 0, tid); + } +#else + internal_execute_st(handle, kind, 0/*start_thread*/, 0/*tid*/); +#endif +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c b/third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c new file mode 100644 index 0000000000000000000000000000000000000000..32da0d18de8a5fbd7b83ad356ba25ae3d2c40a86 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c @@ -0,0 +1,719 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_convolution_backward.h" +#include "libxsmm_main.h" + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_vnni_transpose_16x16_kernel(void* source_void, void* dest_void, int source_stride, int dest_stride) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + libxsmm_bfloat16 *source = (libxsmm_bfloat16*)source_void; + libxsmm_bfloat16 *dest = (libxsmm_bfloat16*)dest_void; + __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512i tmp0, tmp1, tmp2, tmp3; + const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100); + + zmm0 = _mm512_load_epi32(source); + zmm1 = _mm512_load_epi32(source + source_stride); + zmm2 = _mm512_load_epi32(source + source_stride*2); + zmm3 = _mm512_load_epi32(source + source_stride*3); + zmm4 = _mm512_load_epi32(source + source_stride*4); + zmm5 = _mm512_load_epi32(source + source_stride*5); + zmm6 = _mm512_load_epi32(source + source_stride*6); + zmm7 = _mm512_load_epi32(source + source_stride*7); + + zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh); + zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh); + zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh); + zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh); + zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh); + zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh); + zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh); + zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh); + + tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1); + tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1); + tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3); + tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3); + zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5); + zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5); + zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7); + zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7); + + zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88); + zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd); + zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88); + zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd); + tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88); + tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd); + tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88); + tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd); + + zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88); + zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88); + zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88); + zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88); + zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd); + zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd); + zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd); + zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd); + + _mm512_store_epi32(dest, zmm0); + _mm512_store_epi32(dest + dest_stride, zmm1); + _mm512_store_epi32(dest + dest_stride * 2, zmm2); + _mm512_store_epi32(dest + dest_stride * 3, zmm3); + _mm512_store_epi32(dest + dest_stride * 4, zmm4); + _mm512_store_epi32(dest + dest_stride * 5, zmm5); + _mm512_store_epi32(dest + dest_stride * 6, zmm6); + _mm512_store_epi32(dest + dest_stride * 7, zmm7); +#else + LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_vnni_transpose_kernel(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + const int _M = M/16, _N = N/16; + int i = 0, j = 0; + for (i = 0; i < _N; i++) { + for (j = 0; j < _M; j++) { + bf16_vnni_transpose_16x16_kernel((libxsmm_bfloat16*) src+i*16*ld_in+j*32, (libxsmm_bfloat16*) dst+j*16*ld_out+i*32, ld_in*2, ld_out*2); + } + } +#else + LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); +#endif +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock; + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock; + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c" + } + } else { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock); + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, NULL, NULL, &ldC, NULL, NULL, NULL, NULL); +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c" + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + { typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function; + typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16; + const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock; + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock; + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } + } else { + const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock); + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function; + int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16); + int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL); + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + { + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction gemm_function; + typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs; + typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd; + gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs; + gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd; + gemm_function tile_config_kernel = handle->bwd_config_kernel; +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } + } else { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function; + const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock); + int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16); + int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL); + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) + LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + { + typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function; + typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16; + const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock; + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock; + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL); +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } +# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + } else { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function; + const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock); + int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16); + int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL); +# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else + LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid ); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) + LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + { + typedef libxsmm_bsmmfunction gemm_function; + typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs; + typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd; + gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs; + gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd; + gemm_function tile_config_kernel = handle->bwd_config_kernel; +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } +# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + } else { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function; + const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock); + int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16); + int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL); + +# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else + LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid ); +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock); + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM + } + } else { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ? + (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : + (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if (handle->use_fallback_bwd_loops == 0) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock); + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK + } + } else { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ? + (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : + (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx( handle, start_thread, tid); + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + if (handle->use_fallback_bwd_loops == 0) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + const libxsmm_blasint ldx = ((libxsmm_blasint)handle->ofmblock); + const libxsmm_blasint ldA = handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? handle->ifmblock * handle->desc.v : handle->ifmblock; + const float beta = (handle->avoid_acc_load_bwd) ? 0.f : 1.f; + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c" + } + } else { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + const libxsmm_blasint ldx = ((libxsmm_blasint)handle->desc.v*handle->ifmblock); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, NULL, NULL, &ldx, NULL, NULL, NULL, NULL); +# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c" + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + if (handle->use_fallback_bwd_loops == 0) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock); + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK + } + } else { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ? + (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : + (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + if (handle->use_fallback_bwd_loops == 0) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v) : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock); + const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f); + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM + } + } else { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock); + const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock; + const libxsmm_blasint ldC = ( (handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in) ) ? + (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : + (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v); + /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */ + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL); +# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h b/third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..ed1928d014e6b3e0c04514709804a6eb4bd23622 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Rajkishore Barik, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_CONVOLUTION_BACKWARD_H +#define LIBXSMM_DNN_CONVOLUTION_BACKWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_CONVOLUTION_BACKWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c b/third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..b56b60b61445a9b402f1ab5c5dd7402110225fc3 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c @@ -0,0 +1,544 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_convolution_forward.h" +#include "libxsmm_main.h" + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; +#if 1 + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr; + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->ofmblock; + const libxsmm_blasint ldC = handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +#else + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr; + typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs; + typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd; + + { + gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32; + gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32; + gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32; + gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32; +#endif +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c" + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + { + typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function; + typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16; + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->ofmblock; + const libxsmm_blasint ldC = handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') )| handle->fwd_flags; + + /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction gemm_function; + typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a; + typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b; + typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd; + gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a; + gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b; + gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd; + gemm_function tile_config_kernel = handle->fwd_config_kernel; +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c" +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function; + typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16; + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->ofmblock; + const libxsmm_blasint ldC = handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags; + gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); + gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL); +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c" + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid ); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction gemm_function; + typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a; + typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b; + typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd; + gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a; + gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b; + gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd; + gemm_function tile_config_kernel = handle->fwd_config_kernel; +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c" + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid ); +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef unsigned char element_input_type; + typedef int element_output_type; + typedef char element_filter_type; + /* Basically we need only offset based and strided BRGEMMs */ + libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.subimrs; + libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided2 = handle->gemm_fwd2.xgemm.subimrs; + libxsmm_subimmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.subimro; +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c" +#else + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef unsigned char element_input_type; + typedef unsigned char element_output_type; + typedef char element_filter_type; + /* Basically we need only offset based and strided BRGEMMs */ + libxsmm_sububmmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.sububmrs; + libxsmm_sububmmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.sububmro; +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c" +#else + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->ofmblock; + const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->blocksofm*handle->ofmblock; + const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I32 ) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I8 ) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx( handle, start_thread, tid); + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; +#if 1 + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr; + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->ofmblock; + const libxsmm_blasint ldC = handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +#else + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr; + typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs; + typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd; + + { + gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32; + gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32; + gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32; + gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32; +#endif +# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c" + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->ofmblock; + const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock; + const libxsmm_blasint ldA = handle->blocksofm*handle->ofmblock; + const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock; + const float beta = (handle->avoid_acc_load) ? 0.f : 1.f; + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; + int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags; + int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + int brgemm_pf_oob = 0; + const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); + if ( 0 == env_brgemm_pf_oob ) { + } else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); + } + if (brgemm_pf_oob > 0) { + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); + } + { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */ + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode); +# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c" +# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h b/third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..de2c4fdb2727531ec4118edeed889af266b73ead --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_CONVOLUTION_FORWARD_H +#define LIBXSMM_DNN_CONVOLUTION_FORWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_CONVOLUTION_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c b/third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c new file mode 100644 index 0000000000000000000000000000000000000000..c20d74c73c02ab0583b8e5689dca40caead49681 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c @@ -0,0 +1,914 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Rajkishore Barik, Alexander Heinecke, Ankush Mandal, Jason Sewall (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_convolution_weight_update.h" +#include "libxsmm_main.h" + + +/* function prototypes for below implementations */ +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid); + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void transpose_32x16(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + const int in_width=ld_in, out_width=ld_out; + const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10); + + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + re = _mm512_loadu_si512(in + 14*in_width); + rf = _mm512_loadu_si512(in + 15*in_width); + + t0 = _mm512_unpacklo_epi16(r0,r1); + t1 = _mm512_unpackhi_epi16(r0,r1); + t2 = _mm512_unpacklo_epi16(r2,r3); + t3 = _mm512_unpackhi_epi16(r2,r3); + t4 = _mm512_unpacklo_epi16(r4,r5); + t5 = _mm512_unpackhi_epi16(r4,r5); + t6 = _mm512_unpacklo_epi16(r6,r7); + t7 = _mm512_unpackhi_epi16(r6,r7); + t8 = _mm512_unpacklo_epi16(r8,r9); + t9 = _mm512_unpackhi_epi16(r8,r9); + ta = _mm512_unpacklo_epi16(ra,rb); + tb = _mm512_unpackhi_epi16(ra,rb); + tc = _mm512_unpacklo_epi16(rc,rd); + td = _mm512_unpackhi_epi16(rc,rd); + te = _mm512_unpacklo_epi16(re,rf); + tf = _mm512_unpackhi_epi16(re,rf); + + r0 = _mm512_unpacklo_epi32(t0,t2); + r1 = _mm512_unpackhi_epi32(t0,t2); + r2 = _mm512_unpacklo_epi32(t1,t3); + r3 = _mm512_unpackhi_epi32(t1,t3); + r4 = _mm512_unpacklo_epi32(t4,t6); + r5 = _mm512_unpackhi_epi32(t4,t6); + r6 = _mm512_unpacklo_epi32(t5,t7); + r7 = _mm512_unpackhi_epi32(t5,t7); + r8 = _mm512_unpacklo_epi32(t8,ta); + r9 = _mm512_unpackhi_epi32(t8,ta); + ra = _mm512_unpacklo_epi32(t9,tb); + rb = _mm512_unpackhi_epi32(t9,tb); + rc = _mm512_unpacklo_epi32(tc,te); + rd = _mm512_unpackhi_epi32(tc,te); + re = _mm512_unpacklo_epi32(td,tf); + rf = _mm512_unpackhi_epi32(td,tf); + + t0 = _mm512_unpacklo_epi64(r0,r4); + t1 = _mm512_unpackhi_epi64(r0,r4); + t2 = _mm512_unpacklo_epi64(r1,r5); + t3 = _mm512_unpackhi_epi64(r1,r5); + t4 = _mm512_unpacklo_epi64(r2,r6); + t5 = _mm512_unpackhi_epi64(r2,r6); + t6 = _mm512_unpacklo_epi64(r3,r7); + t7 = _mm512_unpackhi_epi64(r3,r7); + t8 = _mm512_unpacklo_epi64(r8,rc); + t9 = _mm512_unpackhi_epi64(r8,rc); + ta = _mm512_unpacklo_epi64(r9,rd); + tb = _mm512_unpackhi_epi64(r9,rd); + tc = _mm512_unpacklo_epi64(ra,re); + td = _mm512_unpackhi_epi64(ra,re); + te = _mm512_unpacklo_epi64(rb,rf); + tf = _mm512_unpackhi_epi64(rb,rf); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r2 = _mm512_shuffle_i32x4(t4, t5, 0x88); + r3 = _mm512_shuffle_i32x4(t6, t7, 0x88); + r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd); + r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd); + r8 = _mm512_shuffle_i32x4(t8, t9, 0x88); + r9 = _mm512_shuffle_i32x4(ta, tb, 0x88); + ra = _mm512_shuffle_i32x4(tc, td, 0x88); + rb = _mm512_shuffle_i32x4(te, tf, 0x88); + rc = _mm512_shuffle_i32x4(t8, t9, 0xdd); + rd = _mm512_shuffle_i32x4(ta, tb, 0xdd); + re = _mm512_shuffle_i32x4(tc, td, 0xdd); + rf = _mm512_shuffle_i32x4(te, tf, 0xdd); + + t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8); + t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9); + t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra); + t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb); + t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc); + t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd); + t6 = _mm512_permutex2var_epi64(r6, idx_lo, re); + t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf); + t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0); + t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1); + ta = _mm512_permutex2var_epi64(ra, idx_hi, r2); + tb = _mm512_permutex2var_epi64(rb, idx_hi, r3); + tc = _mm512_permutex2var_epi64(rc, idx_hi, r4); + td = _mm512_permutex2var_epi64(rd, idx_hi, r5); + te = _mm512_permutex2var_epi64(re, idx_hi, r6); + tf = _mm512_permutex2var_epi64(rf, idx_hi, r7); + + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 0*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 1*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 2*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 3*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 4*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 5*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 6*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 7*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 8*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 9*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 10*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 11*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 12*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 13*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 14*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 15*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 16*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 17*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 18*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 19*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 20*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 21*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 22*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 23*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 24*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 25*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 26*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 27*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 28*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 29*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 30*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 31*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1)); +#else + LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void transpose_32xcols(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + const int in_width=ld_in, out_width=ld_out; + const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10); + __mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1); + + rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + if (col == 15) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + re = _mm512_loadu_si512(in + 14*in_width); + } else if (col == 14) { + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + } else if (col == 13) { + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + } else if (col == 12) { + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + } else if (col == 11) { + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + } else if (col == 10) { + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + } else if (col == 9) { + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + } else if (col == 8) { + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + } else if (col == 7) { + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + } else if (col == 6) { + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + } else if (col == 5) { + r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + } else if (col == 4) { + r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + } else if (col == 3) { + r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + } else if (col == 2) { + r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + } else if (col == 1) { + r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r0 = _mm512_loadu_si512(in + 0*in_width); + } else { + r0 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + } + + t0 = _mm512_unpacklo_epi16(r0,r1); + t1 = _mm512_unpackhi_epi16(r0,r1); + t2 = _mm512_unpacklo_epi16(r2,r3); + t3 = _mm512_unpackhi_epi16(r2,r3); + t4 = _mm512_unpacklo_epi16(r4,r5); + t5 = _mm512_unpackhi_epi16(r4,r5); + t6 = _mm512_unpacklo_epi16(r6,r7); + t7 = _mm512_unpackhi_epi16(r6,r7); + t8 = _mm512_unpacklo_epi16(r8,r9); + t9 = _mm512_unpackhi_epi16(r8,r9); + ta = _mm512_unpacklo_epi16(ra,rb); + tb = _mm512_unpackhi_epi16(ra,rb); + tc = _mm512_unpacklo_epi16(rc,rd); + td = _mm512_unpackhi_epi16(rc,rd); + te = _mm512_unpacklo_epi16(re,rf); + tf = _mm512_unpackhi_epi16(re,rf); + + r0 = _mm512_unpacklo_epi32(t0,t2); + r1 = _mm512_unpackhi_epi32(t0,t2); + r2 = _mm512_unpacklo_epi32(t1,t3); + r3 = _mm512_unpackhi_epi32(t1,t3); + r4 = _mm512_unpacklo_epi32(t4,t6); + r5 = _mm512_unpackhi_epi32(t4,t6); + r6 = _mm512_unpacklo_epi32(t5,t7); + r7 = _mm512_unpackhi_epi32(t5,t7); + r8 = _mm512_unpacklo_epi32(t8,ta); + r9 = _mm512_unpackhi_epi32(t8,ta); + ra = _mm512_unpacklo_epi32(t9,tb); + rb = _mm512_unpackhi_epi32(t9,tb); + rc = _mm512_unpacklo_epi32(tc,te); + rd = _mm512_unpackhi_epi32(tc,te); + re = _mm512_unpacklo_epi32(td,tf); + rf = _mm512_unpackhi_epi32(td,tf); + + t0 = _mm512_unpacklo_epi64(r0,r4); + t1 = _mm512_unpackhi_epi64(r0,r4); + t2 = _mm512_unpacklo_epi64(r1,r5); + t3 = _mm512_unpackhi_epi64(r1,r5); + t4 = _mm512_unpacklo_epi64(r2,r6); + t5 = _mm512_unpackhi_epi64(r2,r6); + t6 = _mm512_unpacklo_epi64(r3,r7); + t7 = _mm512_unpackhi_epi64(r3,r7); + t8 = _mm512_unpacklo_epi64(r8,rc); + t9 = _mm512_unpackhi_epi64(r8,rc); + ta = _mm512_unpacklo_epi64(r9,rd); + tb = _mm512_unpackhi_epi64(r9,rd); + tc = _mm512_unpacklo_epi64(ra,re); + td = _mm512_unpackhi_epi64(ra,re); + te = _mm512_unpacklo_epi64(rb,rf); + tf = _mm512_unpackhi_epi64(rb,rf); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r2 = _mm512_shuffle_i32x4(t4, t5, 0x88); + r3 = _mm512_shuffle_i32x4(t6, t7, 0x88); + r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd); + r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd); + r8 = _mm512_shuffle_i32x4(t8, t9, 0x88); + r9 = _mm512_shuffle_i32x4(ta, tb, 0x88); + ra = _mm512_shuffle_i32x4(tc, td, 0x88); + rb = _mm512_shuffle_i32x4(te, tf, 0x88); + rc = _mm512_shuffle_i32x4(t8, t9, 0xdd); + rd = _mm512_shuffle_i32x4(ta, tb, 0xdd); + re = _mm512_shuffle_i32x4(tc, td, 0xdd); + rf = _mm512_shuffle_i32x4(te, tf, 0xdd); + + t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8); + t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9); + t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra); + t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb); + t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc); + t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd); + t6 = _mm512_permutex2var_epi64(r6, idx_lo, re); + t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf); + t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0); + t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1); + ta = _mm512_permutex2var_epi64(ra, idx_hi, r2); + tb = _mm512_permutex2var_epi64(rb, idx_hi, r3); + tc = _mm512_permutex2var_epi64(rc, idx_hi, r4); + td = _mm512_permutex2var_epi64(rd, idx_hi, r5); + te = _mm512_permutex2var_epi64(re, idx_hi, r6); + tf = _mm512_permutex2var_epi64(rf, idx_hi, r7); + + _mm256_mask_storeu_epi16(out + 0*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0)); + _mm256_mask_storeu_epi16(out + 1*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1)); + _mm256_mask_storeu_epi16(out + 2*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0)); + _mm256_mask_storeu_epi16(out + 3*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1)); + _mm256_mask_storeu_epi16(out + 4*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0)); + _mm256_mask_storeu_epi16(out + 5*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1)); + _mm256_mask_storeu_epi16(out + 6*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0)); + _mm256_mask_storeu_epi16(out + 7*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1)); + _mm256_mask_storeu_epi16(out + 8*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0)); + _mm256_mask_storeu_epi16(out + 9*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1)); + _mm256_mask_storeu_epi16(out + 10*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0)); + _mm256_mask_storeu_epi16(out + 11*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1)); + _mm256_mask_storeu_epi16(out + 12*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0)); + _mm256_mask_storeu_epi16(out + 13*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1)); + _mm256_mask_storeu_epi16(out + 14*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0)); + _mm256_mask_storeu_epi16(out + 15*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1)); + _mm256_mask_storeu_epi16(out + 16*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0)); + _mm256_mask_storeu_epi16(out + 17*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1)); + _mm256_mask_storeu_epi16(out + 18*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0)); + _mm256_mask_storeu_epi16(out + 19*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1)); + _mm256_mask_storeu_epi16(out + 20*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0)); + _mm256_mask_storeu_epi16(out + 21*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1)); + _mm256_mask_storeu_epi16(out + 22*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0)); + _mm256_mask_storeu_epi16(out + 23*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1)); + _mm256_mask_storeu_epi16(out + 24*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0)); + _mm256_mask_storeu_epi16(out + 25*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1)); + _mm256_mask_storeu_epi16(out + 26*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0)); + _mm256_mask_storeu_epi16(out + 27*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1)); + _mm256_mask_storeu_epi16(out + 28*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0)); + _mm256_mask_storeu_epi16(out + 29*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1)); + _mm256_mask_storeu_epi16(out + 30*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0)); + _mm256_mask_storeu_epi16(out + 31*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1)); +#else + LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(col); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void transpose_input_pixels_bf16(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out){ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + int i, j; + int full16_chunks = N/16; + int remainder_cols = N%16; + int _N = N - remainder_cols; + + if (full16_chunks) { + for (i=0; i FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function; +# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c" + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction gemm_function; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function; + gemm_function tile_config_kernel = handle->upd_config_kernel; + gemm_function gemm_kernel = NULL; + gemm_br_function br_gemm_kernel = NULL; +# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c" + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction gemm_function; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function; +# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c" + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu( handle, start_thread, tid ); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_bsmmfunction gemm_function; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function; + gemm_function tile_config_kernel = handle->upd_config_kernel; + gemm_function gemm_kernel = NULL; + gemm_br_function br_gemm_kernel = NULL; +# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c" + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid ); +} +#endif + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; +#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c" +#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; +#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c" +#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx( handle, start_thread, tid); + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; +# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c" + } + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; +#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM +# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c" +#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM + } + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ + if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function; +#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK +# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c" +#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK + } + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h b/third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h new file mode 100644 index 0000000000000000000000000000000000000000..2966a80ea94e191a4aab31f95a870aef9d49d1d9 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Rajkishore Barik, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H +#define LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_elementwise.c b/third_party/libxsmm/src/libxsmm_dnn_elementwise.c new file mode 100644 index 0000000000000000000000000000000000000000..06d5782a03eec83be659b96c6478feaf89dc2e2a --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_elementwise.c @@ -0,0 +1,618 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_elementwise.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + src[i] = (LIBXSMM_DNN_ELTWISE_FTYPE)0; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_add(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + c[i] = a[i] + b[i]; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + c[i] = a[i] * b[i]; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + const LIBXSMM_DNN_ELTWISE_FTYPE exp_value = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[i]); + dst[i] = 1 / (1 + exp_value); + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double)src[i]); + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = (src[i] > 0.0f) ? src[i] : 0.0f; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + const LIBXSMM_DNN_ELTWISE_FTYPE exp_value = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[i]); + const LIBXSMM_DNN_ELTWISE_FTYPE sig_exp = 1 / (1 + exp_value); + dst[i] = (1 - sig_exp)*sig_exp; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + const LIBXSMM_DNN_ELTWISE_FTYPE tanh_value = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double)src[i]); + dst[i] = 1 - (tanh_value * tanh_value); + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = (LIBXSMM_DNN_ELTWISE_FTYPE)(src[i] > 0.0f ? 1.0f : 0.0f); + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_transpose(libxsmm_blasint rows, libxsmm_blasint cols, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* number of tasks that could be run in parallel */ + const libxsmm_blasint size = rows * cols; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + LIBXSMM_VLA_DECL(2, LIBXSMM_DNN_ELTWISE_FTYPE, src2D, src, cols); + LIBXSMM_VLA_DECL(2, LIBXSMM_DNN_ELTWISE_FTYPE, dst2D, dst, rows); + libxsmm_blasint job; + + for (job = thr_begin; job < thr_end; ++job) { + const libxsmm_blasint i = job / cols; + const libxsmm_blasint j = job % cols; + LIBXSMM_VLA_ACCESS(2, dst2D, j, i, rows) = LIBXSMM_VLA_ACCESS(2, src2D, i, j, cols); + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = src[i]; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = 1 - src[i]; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = 1 - (src[i] * src[i]); + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size); + libxsmm_blasint i; + + for (i = thr_begin; i < thr_end; i++) { + dst[i] = -src[i]; + } +} + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_1D_2D(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint bm, libxsmm_blasint bn, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads) +{ + const int ltid = tid - start_thread; + /* compute chunk size */ + const libxsmm_blasint chunksize = (m % nthreads == 0) ? (m / nthreads) : (m / nthreads) + 1; + /* compute thr_begin and thr_end */ + const libxsmm_blasint thr_begin = (ltid * chunksize < m) ? (ltid * chunksize) : m; + const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, m); + libxsmm_blasint i, j; + LIBXSMM_VLA_DECL(4, LIBXSMM_DNN_ELTWISE_FTYPE, real_dst, (LIBXSMM_DNN_ELTWISE_FTYPE*)dst, m/bm, bn, bm); + + for (i = thr_begin; i < thr_end; i++) { + const libxsmm_blasint mb = i/bm; + const libxsmm_blasint ibm = i%bm; + for (j = 0; j < n; j++) { + const libxsmm_blasint nb = j/bn; + const libxsmm_blasint ibn = j%bn; + LIBXSMM_VLA_ACCESS(4, real_dst, nb, mb, ibn, ibm, m/bm, bn, bm) = src[i]; + } + } +} + + +/* #define LSTM_TIMING */ +#if defined(LSTM_TIMING) +extern double Gbl_t_input_total, Gbl_t_recur_total, Gbl_t_eltwise_total, Gbl_t_nonlin_total; +extern unsigned long long Gbl_t_input, Gbl_t_recur, Gbl_t_eltwise, Gbl_t_nonlin; +extern double Gbl_duration_input, Gbl_duration_recur, Gbl_duration_eltwise, Gbl_duration_nonlin; +#endif + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + srcdst[(j*ld)+i] = (LIBXSMM_DNN_ELTWISE_FTYPE)0; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = src[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = src0[(j*ld)+i] + src1[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sub_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = src0[(j*ld)+i] - src1[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = src0[(j*ld)+i] * src1[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + srcdst[(j*ld)+i] *= src0[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_fma_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] += src0[(j*ld)+i] * src1[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + srcdst[(j*ld)+i] += colv[i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + srcdst[(j*ld)+i] = colv[i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv) { + libxsmm_blasint i, j; + libxsmm_bfloat16_hp t; + + t.i[0] = 0; + for ( j = 0; j < n; ++j ) { + for ( i = 0; i < m; ++i ) { + t.i[1] = colv[i]; + srcdst[(j*ld)+i] = t.f; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + srcdst[(j*ld)+i] = colv[i] + const_bias; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias) { + libxsmm_blasint i, j; + libxsmm_bfloat16_hp t; + + t.i[0] = 0; + for ( j = 0; j < n; ++j ) { + for ( i = 0; i < m; ++i ) { + t.i[1] = colv[i]; + srcdst[(j*ld)+i] = t.f + const_bias; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + const LIBXSMM_DNN_ELTWISE_FTYPE mid_value = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[(j*ld)+i]); + dst[(j*ld)+i] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 / ((LIBXSMM_DNN_ELTWISE_FTYPE)1 + mid_value); + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double) src[(j*ld)+i]); + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = (src[(j*ld)+i] < 0) ? (LIBXSMM_DNN_ELTWISE_FTYPE)0 : src[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + LIBXSMM_DNN_ELTWISE_FTYPE exp_value = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[(j*ld)+i]); + LIBXSMM_DNN_ELTWISE_FTYPE mid_value = (LIBXSMM_DNN_ELTWISE_FTYPE)1 / ((LIBXSMM_DNN_ELTWISE_FTYPE)1 + exp_value); + dst[(j*ld)+i] = ((LIBXSMM_DNN_ELTWISE_FTYPE)1 - mid_value) * mid_value; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + LIBXSMM_DNN_ELTWISE_FTYPE tanh_value = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double) src[(j*ld)+i]); + dst[(j*ld)+i] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 - (tanh_value * tanh_value); + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = (src[(j*ld)+i] < 0) ? (LIBXSMM_DNN_ELTWISE_FTYPE)0 : (LIBXSMM_DNN_ELTWISE_FTYPE)1; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + LIBXSMM_DNN_ELTWISE_FTYPE exp_value = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[(j*ld)+i]); + LIBXSMM_DNN_ELTWISE_FTYPE mid_value = (LIBXSMM_DNN_ELTWISE_FTYPE)1 / ((LIBXSMM_DNN_ELTWISE_FTYPE)1 + exp_value); + dst[(j*ld)+i] *= ((LIBXSMM_DNN_ELTWISE_FTYPE)1 - mid_value) * mid_value; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + LIBXSMM_DNN_ELTWISE_FTYPE tanh_value = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double) src[(j*ld)+i]); + dst[(j*ld)+i] *= (LIBXSMM_DNN_ELTWISE_FTYPE)1 - (tanh_value * tanh_value); + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] *= (src[(j*ld)+i] < 0) ? (LIBXSMM_DNN_ELTWISE_FTYPE)0 : (LIBXSMM_DNN_ELTWISE_FTYPE)1; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 - src[(j*ld)+i]; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i = 0, j; + + for ( j = 0; j < n; ++j ) { + LIBXSMM_PRAGMA_SIMD + for ( i = 0; i < m; ++i ) { + dst[(j*ld)+i] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 - (src[(j*ld)+i] * src[(j*ld)+i]); + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_mask_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, float* dst) { + libxsmm_blasint i,j; + + /* rnaz buffer to bfp16 */ + for ( j = 0; j < n; ++j ) { + for ( i = 0; i < m; ++i ) { + unsigned int int_round = 0; + unsigned int do_round = 1; + const void *const ptr = &int_round; + + int_round = *((unsigned int*)&(src[(j*ld)+i])); + + /* we don't round NaN and inf */ + if ( (int_round & 0x7f800000) == 0x7f800000 ) { + do_round = 0; + } + + /* perform round nearest tie even */ + if ( do_round != 0 ) { + unsigned int fixup = (int_round >> 16) & 1; + int_round = int_round + 0x00007fff + fixup; + } + + /* chop bits to create BFP16 in FP32 */ + int_round = int_round & 0xffff0000; + + dst[(j*ld)+i] = *((float*)ptr); + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_cvt_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, libxsmm_bfloat16* dst) { + libxsmm_blasint i,j; + + /* truncate buffer to bfp16 */ + for ( j = 0; j < n; ++j ) { + for ( i = 0; i < m; ++i ) { + unsigned int int_round = 0; + unsigned int do_round = 1; + int_round = *((unsigned int*)&(src[(j*ld)+i])); + /* we don't round NaN and inf */ + if ( (int_round & 0x7f800000) == 0x7f800000 ) { + do_round = 0; + } + /* perform round nearest tie even */ + if ( do_round != 0 ) { + unsigned int fixup = (int_round >> 16) & 1; + int_round = int_round + 0x00007fff + fixup; + } + /* create the bfp16 value by shifting out the lower 16bits */ + int_round = int_round >> 16; + dst[(j*ld)+i] = (unsigned short)int_round; + } + } +} + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_cvt_bf16_fp32_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, libxsmm_bfloat16 *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) { + libxsmm_blasint i, j; + libxsmm_bfloat16_hp t; + + t.i[0] = 0; + for ( j = 0; j < n; ++j ) { + for ( i = 0; i < m; ++i ) { + t.i[1] = src[(j*ld)+i]; + dst[(j*ld)+i] = t.f; + } + } +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_elementwise.h b/third_party/libxsmm/src/libxsmm_dnn_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..aea0b12936b3760c68c3e6737cdcb3863abbf2c8 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_elementwise.h @@ -0,0 +1,65 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_ELEMENTWISE_H +#define LIBXSMM_DNN_ELEMENTWISE_H + +#include + +#if !defined(LIBXSMM_DNN_ELTWISE_FTYPE) +# define LIBXSMM_DNN_ELTWISE_FTYPE float +#endif + + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_add(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_transpose(libxsmm_blasint rows, libxsmm_blasint cols, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_1D_2D(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint bm, libxsmm_blasint bn, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads); + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sub_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_fma_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); + +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_mask_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, float* dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_cvt_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, libxsmm_bfloat16* dst); +LIBXSMM_API_INTERN void libxsmm_internal_matrix_cvt_bf16_fp32_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, libxsmm_bfloat16 *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst); +#endif /*LIBXSMM_DNN_ELEMENTWISE_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fullyconnected.c b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected.c new file mode 100644 index 0000000000000000000000000000000000000000..9fde7fce91cb5a85f7517cc94aa64ebca3d44bfd --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected.c @@ -0,0 +1,1514 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fullyconnected_backward_weight_update.h" +#include "libxsmm_dnn_fullyconnected_forward.h" +#include "libxsmm_main.h" + +LIBXSMM_API libxsmm_dnn_fullyconnected* libxsmm_dnn_create_fullyconnected(libxsmm_dnn_fullyconnected_desc fullyconnected_desc, libxsmm_dnn_err_t* status) { + libxsmm_dnn_fullyconnected* handle = 0; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) || + ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_fullyconnected*)calloc(1, sizeof(libxsmm_dnn_fullyconnected)); + + if (0 != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* let's make the description persistent */ + handle->desc = fullyconnected_desc; + handle->target_archid = libxsmm_target_archid; + if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && ((handle->desc.C % 16 != 0) || (handle->desc.K % 16 != 0)) ) { + handle->target_archid = LIBXSMM_X86_AVX512_CPX; + } + + /* @TODO perhaps we need a better switch here */ + if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + handle->bk = handle->desc.bk; + handle->bn = handle->desc.bn; + handle->bc = handle->desc.bc; + + if ( handle->desc.N % handle->bn != 0 ) { + handle->bn = handle->desc.N; + *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING; + } + if ( handle->desc.C % handle->bc != 0 ) { + handle->bc = handle->desc.C; + *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING; + } + if ( handle->desc.K % handle->bk != 0 ) { + handle->bk = handle->desc.K; + *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING; + } + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { +#if 0 + handle->fwd_bf = atoi(getenv("FWD_BF")); + handle->bwd_bf = atoi(getenv("BWD_BF")); + handle->upd_bf = atoi(getenv("UPD_BF")); + handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING")); + handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING")); + handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING")); + handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS")); + handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS")); + handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS")); + handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS")); + handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS")); + handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS")); + handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS")); + handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS")); +#else + /* Initialize with default values */ + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 0; + handle->bwd_2d_blocking = 0; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 1; + handle->fwd_column_teams = 1; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + + if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 14; + handle->fwd_column_teams = 2; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 7; + handle->fwd_column_teams = 4; + handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 7; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 7; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 1; + handle->fwd_column_teams = 1; + handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 1; + handle->fwd_column_teams = 1; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 14; + handle->bwd_column_teams = 2; + handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 6; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 6; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 6; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 12; + handle->bwd_column_teams = 2; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 6; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 6; + handle->upd_column_teams = 4; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } +#endif + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { +#if 0 + handle->fwd_bf = atoi(getenv("FWD_BF")); + handle->bwd_bf = atoi(getenv("BWD_BF")); + handle->upd_bf = atoi(getenv("UPD_BF")); + handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING")); + handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING")); + handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING")); + handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS")); + handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS")); + handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS")); + handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS")); + handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS")); + handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS")); + handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS")); + handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS")); +#else + if (handle->desc.compressed_A > 0) { + handle->compressed_A = 1; + handle->sparsity_factor_A = handle->desc.sparsity_factor_A; + } + + /* Initialize with default values */ + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 0; + handle->bwd_2d_blocking = 0; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 1; + handle->fwd_column_teams = 1; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + + if (handle->desc.threads == 14) { + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 1; + handle->bwd_2d_blocking = 1; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 2; + handle->fwd_column_teams = 7; + handle->bwd_row_teams = 2; + handle->bwd_column_teams = 7; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + } + + if (handle->desc.threads == 2) { + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 1; + handle->bwd_2d_blocking = 1; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 2; + handle->fwd_column_teams = 1; + handle->bwd_row_teams = 2; + handle->bwd_column_teams = 1; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + } + + if (handle->desc.threads == 4) { + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 1; + handle->bwd_2d_blocking = 1; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 2; + handle->fwd_column_teams = 2; + handle->bwd_row_teams = 2; + handle->bwd_column_teams = 2; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + } + + if (handle->desc.threads == 8) { + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 1; + handle->bwd_2d_blocking = 1; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 2; + handle->fwd_column_teams = 4; + handle->bwd_row_teams = 2; + handle->bwd_column_teams = 4; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + } + + if (handle->desc.threads == 16) { + handle->fwd_bf = 1; + handle->bwd_bf = 1; + handle->upd_bf = 1; + handle->fwd_2d_blocking = 1; + handle->bwd_2d_blocking = 1; + handle->upd_2d_blocking = 0; + handle->fwd_row_teams = 2; + handle->fwd_column_teams = 8; + handle->bwd_row_teams = 2; + handle->bwd_column_teams = 8; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1; + handle->ofm_subtasks = 1; + } + + if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 14; + handle->fwd_column_teams = 2; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 7; + handle->fwd_column_teams = 4; + handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 7; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 7; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 1; + handle->fwd_column_teams = 1; + handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 1; + handle->fwd_column_teams = 1; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 14; + handle->bwd_column_teams = 2; + handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + + if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + } + + if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 6; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 6; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 6; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 12; + handle->bwd_column_teams = 2; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 1; + handle->bwd_column_teams = 1; + handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 1; + handle->upd_column_teams = 1; + handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 5; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 0; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 5; + handle->upd_column_teams = 4; + handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } + if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) { + handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/; + handle->fwd_2d_blocking = 0; + handle->fwd_row_teams = 6; + handle->fwd_column_teams = 4; + handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/; + handle->bwd_2d_blocking = 1; + handle->bwd_row_teams = 5; + handle->bwd_column_teams = 4; + handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/; + handle->upd_2d_blocking = 0; + handle->upd_row_teams = 6; + handle->upd_column_teams = 4; + handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/; + } +#endif + + /* In this case force 2D decomposition */ + if (handle->compressed_A == 1) { + handle->fwd_2d_blocking = 1; + handle->fwd_row_teams = 2; + while (handle->desc.threads % handle->fwd_row_teams != 0) { + handle->fwd_row_teams--; + } + handle->fwd_column_teams = handle->desc.threads/handle->fwd_row_teams; + } + + } + } else { + /* check that we cannot fuse */ + if ( handle->desc.fuse_ops != LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + free( handle ); + *status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + return 0; + } + + /* we need to compute the memory layout given the */ + if ( (handle->desc.C % 16 == 0) && (handle->desc.K % 16 == 0) ) { + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, + &(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block), + LIBXSMM_DNN_DATATYPE_F32, LIBXSMM_DNN_DATATYPE_F32 ); + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, + &(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block), + handle->desc.datatype_in, handle->desc.datatype_out ); + } else { + /* should not happen, not implemented */ + } + } else if ( (handle->desc.C % 64 == 0) && (handle->desc.K == 1000) ) { + /* @TODO this a hack for the last FC layer */ + handle->ifmblock = 64; + handle->fm_lp_block = 1; + handle->ofmblock = 10; + } else if ( (handle->desc.C % 16 == 0) && (handle->desc.K == 1000) ) { + /* @TODO this a hack for the last FC layer */ + handle->ifmblock = 16; + handle->fm_lp_block = 1; + handle->ofmblock = 10; + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + free( handle ); + return 0; + } + /* compute the outer blocks */ + handle->blocksifm = handle->desc.C / handle->ifmblock; + handle->blocksofm = handle->desc.K / handle->ofmblock; + } + /* create barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + + /* If in SPR, generate tilerelease kernel */ + if ((handle->target_archid >= LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) { + int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bk, handle->bk, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL); + } + /* calculate scratch size */ + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + handle->scratch_size = sizeof(float) * ( ( (size_t)handle->desc.C * (size_t)handle->desc.N ) + ( (size_t)handle->desc.C * (size_t)handle->desc.K ) ); + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + /* Let's allocate maximum required scratch */ + size_t size_fwd = sizeof(float) * LIBXSMM_MAX(handle->desc.K * handle->desc.N, handle->desc.threads * LIBXSMM_MAX(handle->bk * handle->bn, handle->desc.K)); + /* In case of K = 1 we pad A and B to "bk=2" */ + size_t size_bwd = (handle->desc.K != 1) ? ( sizeof(float) * LIBXSMM_MAX(handle->desc.C * handle->desc.N, handle->desc.threads * handle->bc * handle->bn) + sizeof(libxsmm_bfloat16) * handle->desc.C * handle->desc.K ) : ( sizeof(float) * handle->desc.C * handle->desc.N + sizeof(libxsmm_bfloat16) * handle->desc.C * 2 + sizeof(libxsmm_bfloat16) * 2 * handle->desc.N ); + size_t size_upd = sizeof(float) * LIBXSMM_MAX(handle->desc.C * handle->desc.K, handle->desc.threads * handle->bc * handle->bk) + sizeof(libxsmm_bfloat16) * handle->desc.threads * handle->bk * handle->bc + sizeof(libxsmm_bfloat16) * (handle->desc.N * (handle->desc.C + handle->desc.K)); + if (handle->compressed_A == 1) { + size_fwd += handle->desc.threads * handle->desc.C * handle->bk *sizeof(libxsmm_bfloat16); + } + handle->scratch_size = LIBXSMM_MAX(LIBXSMM_MAX(size_fwd, size_bwd), size_upd); + handle->doutput_scratch_mark = handle->scratch_size; + handle->scratch_size += 2 * sizeof(libxsmm_bfloat16) * handle->desc.N * handle->desc.K; + } else { + handle->scratch_size = sizeof(float) * ( (((size_t)handle->desc.C + (size_t)handle->desc.K) * (size_t)handle->desc.N) + ((size_t)handle->desc.C * (size_t)handle->desc.K) ); + } + /* create code pointers in some special cases */ + if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) && ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) ) { + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + float alpha = 1.0f; + /* beta is set to 1 for ncnc kcck format because ifm is split into 2 blocks */ + float beta = 1.0f; + float zerobeta = 0.0f; + int updflags = LIBXSMM_GEMM_FLAGS( 'N', 'T' ); + /* For UPD kernels we consider subtasking... */ + libxsmm_blasint M = handle->bk/handle->ofm_subtasks; + libxsmm_blasint N = handle->bc/handle->ifm_subtasks; + + libxsmm_blasint lda = (libxsmm_blasint)handle->bk; + libxsmm_blasint ldb = (libxsmm_blasint)handle->bc; + libxsmm_blasint ldc = (libxsmm_blasint)handle->bk; + + handle->gemm_fwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); + handle->gemm_fwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL); + handle->gemm_bwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL); + handle->gemm_bwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL); + + /* Transpose kernel used for weight transpose in bwd pass */ + handle->tr_kernel = libxsmm_dispatch_meltw_unary((libxsmm_blasint)(handle->bk), (libxsmm_blasint)(handle->bc), (const libxsmm_blasint*)&(handle->bk), (const libxsmm_blasint*)&(handle->bc), LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + + /* update has different LDs */ + lda = (libxsmm_blasint)handle->bk; + ldb = (libxsmm_blasint)handle->bc; + ldc = (libxsmm_blasint)handle->bk; + handle->gemm_upd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, &updflags, NULL); + handle->gemm_upd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, &updflags, NULL); + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + float alpha = 1.0f; + float beta = 1.0f; + float zerobeta = 0.0f; + /* For UPD kernels we consider subtasking... */ + libxsmm_blasint M = handle->bk/handle->ofm_subtasks; + libxsmm_blasint N = handle->bc/handle->ifm_subtasks; + + libxsmm_blasint lda = (libxsmm_blasint)handle->bk; + libxsmm_blasint ldb = (libxsmm_blasint)handle->bc; + libxsmm_blasint ldc = (libxsmm_blasint)handle->bk; + + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) { + libxsmm_meltw_flags fusion_flags; + int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + libxsmm_blasint unroll_hint = (handle->desc.C/handle->bc)/handle->fwd_bf; + + handle->gemm_fwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &beta, &l_flags, NULL); + handle->gemm_fwd2.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL); + handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bn, handle->bc, &lda, &ldb, &ldc, NULL, &beta, &l_tc_flags, NULL); + handle->gemm_fwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL); + fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_OVERWRITE_C; + handle->gemm_fwd4.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_ACT_RELU_OVERWRITE_C; + handle->gemm_fwd5.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_ACT_SIGM_OVERWRITE_C; + handle->gemm_fwd6.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_RELU_OVERWRITE_C; + handle->gemm_fwd7.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_SIGM_OVERWRITE_C; + handle->gemm_fwd8.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0); + + if (handle->compressed_A == 1) { + fusion_flags = LIBXSMM_MELTW_FLAG_FUSE_NONE; + handle->gemm_fwd9.xgemm.bsmrs_meltwfused = libxsmm_bsmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &beta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + handle->gemm_fwd10.xgemm.bsmrs_meltwfused = libxsmm_bsmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bn, handle->bc, &lda, &ldb, &ldc, NULL, &beta, &l_tc_flags, NULL); + handle->gemm_fwd11.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_OVERWRITE_C; + handle->gemm_fwd12.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_ACT_RELU_OVERWRITE_C; + handle->gemm_fwd13.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_ACT_SIGM_OVERWRITE_C; + handle->gemm_fwd14.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_RELU_OVERWRITE_C; + handle->gemm_fwd15.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_SIGM_OVERWRITE_C; + handle->gemm_fwd16.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0); + } + + /* Also JIT eltwise functions... */ + handle->fwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->bk, handle->bn, &ldc, &ldc, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); + handle->fwd_cvtfp32bf16_relu_kernel = libxsmm_dispatch_meltw_unary(handle->bk, handle->bn, &ldc, &ldc, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_BITMASK, LIBXSMM_MELTW_TYPE_UNARY_RELU); + handle->fwd_sigmoid_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->bk, handle->bn, &ldc, &ldc, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_SIGMOID); + } else { + handle->gemm_fwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); + handle->gemm_fwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL); + handle->gemm_fwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); + } + + /* Special bwd kernels for K == 1 */ + if (handle->desc.K == 1) { + libxsmm_blasint _bk = 2; + handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &beta, NULL, NULL); + handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &zerobeta, NULL, NULL); + } else { + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) { + int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + libxsmm_blasint unroll_hint = (handle->desc.K/handle->bk)/handle->bwd_bf; + handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &ldb, &lda, &ldb, &alpha, &beta, &l_flags, NULL); + handle->gemm_bwd2.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &ldb, &lda, &ldb, &alpha, &zerobeta, &l_flags, NULL); + handle->bwd_config_kernel = libxsmm_bsmmdispatch(handle->bc, handle->bn, handle->bk, &ldb, &lda, &ldb, NULL, &beta, &l_tc_flags, NULL); + handle->gemm_bwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &ldb, &lda, &ldb, &alpha, &zerobeta, &l_flags, NULL); + /* Also JIT eltwise functions... */ + handle->bwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->bc, handle->bn, &ldb, &ldb, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); + handle->bwd_relu_kernel = libxsmm_dispatch_meltw_unary(handle->bc, handle->bn, &ldb, &ldb, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_BITMASK, LIBXSMM_MELTW_TYPE_UNARY_RELU_INV); + } else { + handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL); + handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL); + } + } + lda = (libxsmm_blasint)handle->bk; + ldb = (libxsmm_blasint)handle->bn; + ldc = (libxsmm_blasint)handle->bk; + if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) { + int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG; + int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + libxsmm_blasint unroll_hint = (handle->desc.N/handle->bn)/handle->upd_bf; + handle->gemm_upd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &beta, &l_flags, NULL); + handle->gemm_upd2.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL); + handle->upd_config_kernel = libxsmm_bsmmdispatch(M, N, handle->bn, &lda, &ldb, &ldc, NULL, &beta, &l_tc_flags, NULL); + l_flags = l_flags | LIBXSMM_GEMM_FLAG_VNNI_C; + handle->gemm_upd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd_unroll(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL); + } else { + handle->gemm_upd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); + handle->gemm_upd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL); + + } + } else { + + } + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + } else { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_fullyconnected*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->format = handle->desc.buffer_format; + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = 1; + layout->dim_size[2] = 1; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = 1; + layout->dim_size[2] = 1; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = 1; + layout->dim_size[2] = 1; + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->datatype = handle->desc.datatype_out; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = 1; + layout->dim_size[2] = 1; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = 1; + layout->dim_size[2] = 1; + layout->dim_size[3] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->desc.K; + layout->dim_size[1] = 1; + layout->dim_size[2] = 1; + layout->dim_size[3] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = (unsigned int)handle->bc; + layout->dim_size[1] = (unsigned int)handle->bn; + layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn); + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bn; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_FILTER) ) { + layout->format = handle->desc.filter_format; + layout->tensor_type = LIBXSMM_DNN_FILTER; + + if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 6; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ifmblock; + layout->dim_size[2] = 1; + layout->dim_size[3] = 1; + layout->dim_size[4] = handle->blocksifm; + layout->dim_size[5] = handle->blocksofm; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) || + ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 7; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->fm_lp_block; + layout->dim_size[1] = handle->ofmblock; + layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block; + layout->dim_size[3] = 1; + layout->dim_size[4] = 1; + layout->dim_size[5] = handle->blocksifm; + layout->dim_size[6] = handle->blocksofm; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R; + layout->dim_size[0] = handle->ofmblock * handle->blocksofm; + layout->dim_size[1] = handle->ifmblock * handle->blocksifm; + layout->dim_size[2] = 1; + layout->dim_size[3] = 1; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + + if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bc; + layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + + if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)2; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)handle->bc/2; + layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) { + layout->format = handle->desc.buffer_format; + layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR; + + if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) { + if ( (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = handle->desc.datatype_out; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) { + layout->format = handle->desc.buffer_format; + layout->tensor_type = LIBXSMM_DNN_RELU_MASK; + + if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_I8; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 1; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = handle->desc.N * handle->desc.K; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + +LIBXSMM_API size_t libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status) { + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */ + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API void* libxsmm_dnn_fullyconnected_get_scratch_ptr(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + return handle->scratch; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return 0; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected* handle, const void* scratch) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + /* align the internal scratch buffer if needed */ + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) && + (type != LIBXSMM_DNN_RELU_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + handle->reg_filter = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + handle->grad_filter = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) { + handle->reg_bias = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) { + handle->grad_bias = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + handle->relumask = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) && + (type != LIBXSMM_DNN_RELU_MASK) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + return_tensor = handle->reg_input; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + return_tensor = handle->grad_input; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + return_tensor = handle->reg_output; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + return_tensor = handle->grad_output; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + return_tensor = handle->reg_filter; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + return_tensor = handle->grad_filter; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) { + return_tensor = handle->reg_bias; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) { + return_tensor = handle->grad_bias; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + return_tensor = handle->relumask; + } else { + /* cannot happen */ + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) && + (type != LIBXSMM_DNN_RELU_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + handle->reg_filter = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + handle->grad_filter = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) { + handle->reg_bias = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) { + handle->grad_bias = 0; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + handle->relumask = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) ) { + status = libxsmm_dnn_fullyconnected_st_fwd_custom( handle, start_thread, tid ); + } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC; + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: { + if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_custom( handle, kind, start_thread, tid ); + } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck( handle, kind, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_backward_weight_update.c b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_backward_weight_update.c new file mode 100644 index 0000000000000000000000000000000000000000..d985dc37c34af03f98df268f35b8964547b72376 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_backward_weight_update.c @@ -0,0 +1,1281 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fullyconnected_backward_weight_update.h" +#include "libxsmm_main.h" + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); + +#if 0 +#define USE_CLDEMOTE +#endif + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_vnni_transpose_16x16(void* source_void, void* dest_void, int source_stride, int dest_stride) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + libxsmm_bfloat16 *source = (libxsmm_bfloat16*)source_void; + libxsmm_bfloat16 *dest = (libxsmm_bfloat16*)dest_void; + __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + __m512i tmp0, tmp1, tmp2, tmp3; + const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100); + + zmm0 = _mm512_loadu_si512(source); + zmm1 = _mm512_loadu_si512(source + source_stride); + zmm2 = _mm512_loadu_si512(source + source_stride*2); + zmm3 = _mm512_loadu_si512(source + source_stride*3); + zmm4 = _mm512_loadu_si512(source + source_stride*4); + zmm5 = _mm512_loadu_si512(source + source_stride*5); + zmm6 = _mm512_loadu_si512(source + source_stride*6); + zmm7 = _mm512_loadu_si512(source + source_stride*7); + + zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh); + zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh); + zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh); + zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh); + zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh); + zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh); + zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh); + zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh); + + tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1); + tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1); + tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3); + tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3); + zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5); + zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5); + zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7); + zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7); + + zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88); + zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd); + zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88); + zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd); + tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88); + tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd); + tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88); + tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd); + + zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88); + zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88); + zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88); + zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88); + zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd); + zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd); + zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd); + zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd); + + _mm512_storeu_si512(dest, zmm0); + _mm512_storeu_si512(dest + dest_stride, zmm1); + _mm512_storeu_si512(dest + dest_stride * 2, zmm2); + _mm512_storeu_si512(dest + dest_stride * 3, zmm3); + _mm512_storeu_si512(dest + dest_stride * 4, zmm4); + _mm512_storeu_si512(dest + dest_stride * 5, zmm5); + _mm512_storeu_si512(dest + dest_stride * 6, zmm6); + _mm512_storeu_si512(dest + dest_stride * 7, zmm7); +#ifdef USE_CLDEMOTE + _mm_cldemote(dest); + _mm_cldemote(dest + dest_stride); + _mm_cldemote(dest + dest_stride * 2); + _mm_cldemote(dest + dest_stride * 3); + _mm_cldemote(dest + dest_stride * 4); + _mm_cldemote(dest + dest_stride * 5); + _mm_cldemote(dest + dest_stride * 6); + _mm_cldemote(dest + dest_stride * 7); +#endif +#else + LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_vnni_transpose(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + const int _M = M/16, _N = N/16; + int i = 0, j = 0; + for (i = 0; i < _N; i++) { + for (j = 0; j < _M; j++) { + bf16_vnni_transpose_16x16((libxsmm_bfloat16*) src+i*16*ld_in+j*32, (libxsmm_bfloat16*) dst+j*16*ld_out+i*32, ld_in*2, ld_out*2); + } + } +#else + LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_transpose_32x16(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + const int in_width=ld_in, out_width=ld_out; + const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10); + + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + re = _mm512_loadu_si512(in + 14*in_width); + rf = _mm512_loadu_si512(in + 15*in_width); + + t0 = _mm512_unpacklo_epi16(r0,r1); + t1 = _mm512_unpackhi_epi16(r0,r1); + t2 = _mm512_unpacklo_epi16(r2,r3); + t3 = _mm512_unpackhi_epi16(r2,r3); + t4 = _mm512_unpacklo_epi16(r4,r5); + t5 = _mm512_unpackhi_epi16(r4,r5); + t6 = _mm512_unpacklo_epi16(r6,r7); + t7 = _mm512_unpackhi_epi16(r6,r7); + t8 = _mm512_unpacklo_epi16(r8,r9); + t9 = _mm512_unpackhi_epi16(r8,r9); + ta = _mm512_unpacklo_epi16(ra,rb); + tb = _mm512_unpackhi_epi16(ra,rb); + tc = _mm512_unpacklo_epi16(rc,rd); + td = _mm512_unpackhi_epi16(rc,rd); + te = _mm512_unpacklo_epi16(re,rf); + tf = _mm512_unpackhi_epi16(re,rf); + + r0 = _mm512_unpacklo_epi32(t0,t2); + r1 = _mm512_unpackhi_epi32(t0,t2); + r2 = _mm512_unpacklo_epi32(t1,t3); + r3 = _mm512_unpackhi_epi32(t1,t3); + r4 = _mm512_unpacklo_epi32(t4,t6); + r5 = _mm512_unpackhi_epi32(t4,t6); + r6 = _mm512_unpacklo_epi32(t5,t7); + r7 = _mm512_unpackhi_epi32(t5,t7); + r8 = _mm512_unpacklo_epi32(t8,ta); + r9 = _mm512_unpackhi_epi32(t8,ta); + ra = _mm512_unpacklo_epi32(t9,tb); + rb = _mm512_unpackhi_epi32(t9,tb); + rc = _mm512_unpacklo_epi32(tc,te); + rd = _mm512_unpackhi_epi32(tc,te); + re = _mm512_unpacklo_epi32(td,tf); + rf = _mm512_unpackhi_epi32(td,tf); + + t0 = _mm512_unpacklo_epi64(r0,r4); + t1 = _mm512_unpackhi_epi64(r0,r4); + t2 = _mm512_unpacklo_epi64(r1,r5); + t3 = _mm512_unpackhi_epi64(r1,r5); + t4 = _mm512_unpacklo_epi64(r2,r6); + t5 = _mm512_unpackhi_epi64(r2,r6); + t6 = _mm512_unpacklo_epi64(r3,r7); + t7 = _mm512_unpackhi_epi64(r3,r7); + t8 = _mm512_unpacklo_epi64(r8,rc); + t9 = _mm512_unpackhi_epi64(r8,rc); + ta = _mm512_unpacklo_epi64(r9,rd); + tb = _mm512_unpackhi_epi64(r9,rd); + tc = _mm512_unpacklo_epi64(ra,re); + td = _mm512_unpackhi_epi64(ra,re); + te = _mm512_unpacklo_epi64(rb,rf); + tf = _mm512_unpackhi_epi64(rb,rf); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r2 = _mm512_shuffle_i32x4(t4, t5, 0x88); + r3 = _mm512_shuffle_i32x4(t6, t7, 0x88); + r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd); + r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd); + r8 = _mm512_shuffle_i32x4(t8, t9, 0x88); + r9 = _mm512_shuffle_i32x4(ta, tb, 0x88); + ra = _mm512_shuffle_i32x4(tc, td, 0x88); + rb = _mm512_shuffle_i32x4(te, tf, 0x88); + rc = _mm512_shuffle_i32x4(t8, t9, 0xdd); + rd = _mm512_shuffle_i32x4(ta, tb, 0xdd); + re = _mm512_shuffle_i32x4(tc, td, 0xdd); + rf = _mm512_shuffle_i32x4(te, tf, 0xdd); + + t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8); + t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9); + t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra); + t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb); + t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc); + t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd); + t6 = _mm512_permutex2var_epi64(r6, idx_lo, re); + t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf); + t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0); + t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1); + ta = _mm512_permutex2var_epi64(ra, idx_hi, r2); + tb = _mm512_permutex2var_epi64(rb, idx_hi, r3); + tc = _mm512_permutex2var_epi64(rc, idx_hi, r4); + td = _mm512_permutex2var_epi64(rd, idx_hi, r5); + te = _mm512_permutex2var_epi64(re, idx_hi, r6); + tf = _mm512_permutex2var_epi64(rf, idx_hi, r7); + + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 0*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 1*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 2*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 3*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 4*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 5*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 6*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 7*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 8*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 9*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 10*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 11*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 12*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 13*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 14*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 15*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 16*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 17*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 18*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 19*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 20*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 21*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 22*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 23*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 24*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 25*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 26*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 27*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 28*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 29*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 30*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0)); + LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 31*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1)); +#ifdef USE_CLDEMOTE + _mm_cldemote(out + 0*out_width); + _mm_cldemote(out + 1*out_width); + _mm_cldemote(out + 2*out_width); + _mm_cldemote(out + 3*out_width); + _mm_cldemote(out + 4*out_width); + _mm_cldemote(out + 5*out_width); + _mm_cldemote(out + 6*out_width); + _mm_cldemote(out + 7*out_width); + _mm_cldemote(out + 8*out_width); + _mm_cldemote(out + 9*out_width); + _mm_cldemote(out + 10*out_width); + _mm_cldemote(out + 11*out_width); + _mm_cldemote(out + 12*out_width); + _mm_cldemote(out + 13*out_width); + _mm_cldemote(out + 14*out_width); + _mm_cldemote(out + 15*out_width); + _mm_cldemote(out + 16*out_width); + _mm_cldemote(out + 17*out_width); + _mm_cldemote(out + 18*out_width); + _mm_cldemote(out + 19*out_width); + _mm_cldemote(out + 20*out_width); + _mm_cldemote(out + 21*out_width); + _mm_cldemote(out + 22*out_width); + _mm_cldemote(out + 23*out_width); + _mm_cldemote(out + 24*out_width); + _mm_cldemote(out + 25*out_width); + _mm_cldemote(out + 26*out_width); + _mm_cldemote(out + 27*out_width); + _mm_cldemote(out + 28*out_width); + _mm_cldemote(out + 29*out_width); + _mm_cldemote(out + 30*out_width); + _mm_cldemote(out + 31*out_width); +#endif +#else + LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_transpose_32xcols(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + __m512i r0 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + const int in_width=ld_in, out_width=ld_out; + const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10); + __mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1); + + if (col == 15) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + re = _mm512_loadu_si512(in + 14*in_width); + } else if (col == 14) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + } else if (col == 13) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + rc = _mm512_loadu_si512(in + 12*in_width); + } else if (col == 12) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + } else if (col == 11) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + ra = _mm512_loadu_si512(in + 10*in_width); + } else if (col == 10) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + } else if (col == 9) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + r8 = _mm512_loadu_si512(in + 8*in_width); + } else if (col == 8) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + } else if (col == 7) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + r6 = _mm512_loadu_si512(in + 6*in_width); + } else if (col == 6) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + } else if (col == 5) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + r4 = _mm512_loadu_si512(in + 4*in_width); + } else if (col == 4) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + } else if (col == 3) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + r2 = _mm512_loadu_si512(in + 2*in_width); + } else if (col == 2) { + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + } else if (col == 1) { + r0 = _mm512_loadu_si512(in + 0*in_width); + } + + t0 = _mm512_unpacklo_epi16(r0,r1); + t1 = _mm512_unpackhi_epi16(r0,r1); + t2 = _mm512_unpacklo_epi16(r2,r3); + t3 = _mm512_unpackhi_epi16(r2,r3); + t4 = _mm512_unpacklo_epi16(r4,r5); + t5 = _mm512_unpackhi_epi16(r4,r5); + t6 = _mm512_unpacklo_epi16(r6,r7); + t7 = _mm512_unpackhi_epi16(r6,r7); + t8 = _mm512_unpacklo_epi16(r8,r9); + t9 = _mm512_unpackhi_epi16(r8,r9); + ta = _mm512_unpacklo_epi16(ra,rb); + tb = _mm512_unpackhi_epi16(ra,rb); + tc = _mm512_unpacklo_epi16(rc,rd); + td = _mm512_unpackhi_epi16(rc,rd); + te = _mm512_unpacklo_epi16(re,rf); + tf = _mm512_unpackhi_epi16(re,rf); + + r0 = _mm512_unpacklo_epi32(t0,t2); + r1 = _mm512_unpackhi_epi32(t0,t2); + r2 = _mm512_unpacklo_epi32(t1,t3); + r3 = _mm512_unpackhi_epi32(t1,t3); + r4 = _mm512_unpacklo_epi32(t4,t6); + r5 = _mm512_unpackhi_epi32(t4,t6); + r6 = _mm512_unpacklo_epi32(t5,t7); + r7 = _mm512_unpackhi_epi32(t5,t7); + r8 = _mm512_unpacklo_epi32(t8,ta); + r9 = _mm512_unpackhi_epi32(t8,ta); + ra = _mm512_unpacklo_epi32(t9,tb); + rb = _mm512_unpackhi_epi32(t9,tb); + rc = _mm512_unpacklo_epi32(tc,te); + rd = _mm512_unpackhi_epi32(tc,te); + re = _mm512_unpacklo_epi32(td,tf); + rf = _mm512_unpackhi_epi32(td,tf); + + t0 = _mm512_unpacklo_epi64(r0,r4); + t1 = _mm512_unpackhi_epi64(r0,r4); + t2 = _mm512_unpacklo_epi64(r1,r5); + t3 = _mm512_unpackhi_epi64(r1,r5); + t4 = _mm512_unpacklo_epi64(r2,r6); + t5 = _mm512_unpackhi_epi64(r2,r6); + t6 = _mm512_unpacklo_epi64(r3,r7); + t7 = _mm512_unpackhi_epi64(r3,r7); + t8 = _mm512_unpacklo_epi64(r8,rc); + t9 = _mm512_unpackhi_epi64(r8,rc); + ta = _mm512_unpacklo_epi64(r9,rd); + tb = _mm512_unpackhi_epi64(r9,rd); + tc = _mm512_unpacklo_epi64(ra,re); + td = _mm512_unpackhi_epi64(ra,re); + te = _mm512_unpacklo_epi64(rb,rf); + tf = _mm512_unpackhi_epi64(rb,rf); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r2 = _mm512_shuffle_i32x4(t4, t5, 0x88); + r3 = _mm512_shuffle_i32x4(t6, t7, 0x88); + r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd); + r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd); + r8 = _mm512_shuffle_i32x4(t8, t9, 0x88); + r9 = _mm512_shuffle_i32x4(ta, tb, 0x88); + ra = _mm512_shuffle_i32x4(tc, td, 0x88); + rb = _mm512_shuffle_i32x4(te, tf, 0x88); + rc = _mm512_shuffle_i32x4(t8, t9, 0xdd); + rd = _mm512_shuffle_i32x4(ta, tb, 0xdd); + re = _mm512_shuffle_i32x4(tc, td, 0xdd); + rf = _mm512_shuffle_i32x4(te, tf, 0xdd); + + t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8); + t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9); + t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra); + t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb); + t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc); + t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd); + t6 = _mm512_permutex2var_epi64(r6, idx_lo, re); + t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf); + t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0); + t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1); + ta = _mm512_permutex2var_epi64(ra, idx_hi, r2); + tb = _mm512_permutex2var_epi64(rb, idx_hi, r3); + tc = _mm512_permutex2var_epi64(rc, idx_hi, r4); + td = _mm512_permutex2var_epi64(rd, idx_hi, r5); + te = _mm512_permutex2var_epi64(re, idx_hi, r6); + tf = _mm512_permutex2var_epi64(rf, idx_hi, r7); + + _mm256_mask_storeu_epi16(out + 0*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0)); + _mm256_mask_storeu_epi16(out + 1*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1)); + _mm256_mask_storeu_epi16(out + 2*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0)); + _mm256_mask_storeu_epi16(out + 3*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1)); + _mm256_mask_storeu_epi16(out + 4*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0)); + _mm256_mask_storeu_epi16(out + 5*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1)); + _mm256_mask_storeu_epi16(out + 6*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0)); + _mm256_mask_storeu_epi16(out + 7*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1)); + _mm256_mask_storeu_epi16(out + 8*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0)); + _mm256_mask_storeu_epi16(out + 9*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1)); + _mm256_mask_storeu_epi16(out + 10*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0)); + _mm256_mask_storeu_epi16(out + 11*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1)); + _mm256_mask_storeu_epi16(out + 12*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0)); + _mm256_mask_storeu_epi16(out + 13*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1)); + _mm256_mask_storeu_epi16(out + 14*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0)); + _mm256_mask_storeu_epi16(out + 15*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1)); + _mm256_mask_storeu_epi16(out + 16*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0)); + _mm256_mask_storeu_epi16(out + 17*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1)); + _mm256_mask_storeu_epi16(out + 18*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0)); + _mm256_mask_storeu_epi16(out + 19*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1)); + _mm256_mask_storeu_epi16(out + 20*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0)); + _mm256_mask_storeu_epi16(out + 21*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1)); + _mm256_mask_storeu_epi16(out + 22*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0)); + _mm256_mask_storeu_epi16(out + 23*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1)); + _mm256_mask_storeu_epi16(out + 24*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0)); + _mm256_mask_storeu_epi16(out + 25*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1)); + _mm256_mask_storeu_epi16(out + 26*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0)); + _mm256_mask_storeu_epi16(out + 27*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1)); + _mm256_mask_storeu_epi16(out + 28*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0)); + _mm256_mask_storeu_epi16(out + 29*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1)); + _mm256_mask_storeu_epi16(out + 30*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0)); + _mm256_mask_storeu_epi16(out + 31*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1)); +#ifdef USE_CLDEMOTE + _mm_cldemote(out + 0*out_width); + _mm_cldemote(out + 1*out_width); + _mm_cldemote(out + 2*out_width); + _mm_cldemote(out + 3*out_width); + _mm_cldemote(out + 4*out_width); + _mm_cldemote(out + 5*out_width); + _mm_cldemote(out + 6*out_width); + _mm_cldemote(out + 7*out_width); + _mm_cldemote(out + 8*out_width); + _mm_cldemote(out + 9*out_width); + _mm_cldemote(out + 10*out_width); + _mm_cldemote(out + 11*out_width); + _mm_cldemote(out + 12*out_width); + _mm_cldemote(out + 13*out_width); + _mm_cldemote(out + 14*out_width); + _mm_cldemote(out + 15*out_width); + _mm_cldemote(out + 16*out_width); + _mm_cldemote(out + 17*out_width); + _mm_cldemote(out + 18*out_width); + _mm_cldemote(out + 19*out_width); + _mm_cldemote(out + 20*out_width); + _mm_cldemote(out + 21*out_width); + _mm_cldemote(out + 22*out_width); + _mm_cldemote(out + 23*out_width); + _mm_cldemote(out + 24*out_width); + _mm_cldemote(out + 25*out_width); + _mm_cldemote(out + 26*out_width); + _mm_cldemote(out + 27*out_width); + _mm_cldemote(out + 28*out_width); + _mm_cldemote(out + 29*out_width); + _mm_cldemote(out + 30*out_width); + _mm_cldemote(out + 31*out_width); +#endif +#else + LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); LIBXSMM_UNUSED(col); +#endif +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void bf16_transpose(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out){ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + int i, j; + int full16_chunks = N/16; + int remainder_cols = N%16; + int _N = N - remainder_cols; + + if (full16_chunks) { + for (i=0; iifmblock; + libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N; + libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock; + element_input_type alpha = (element_input_type)1; + element_input_type beta = (element_input_type)0; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + typedef libxsmm_smmfunction gemm_function; + gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL); + gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL); +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c" + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef float element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_smmfunction gemm_function; + libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock; + libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N; + libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock; + float alpha = (element_input_type)1; + float beta = (element_input_type)0; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL); + gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL); +# define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32 +# define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32 +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32 +# undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32 + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs; + +#define LIBXSMM_DNN_FC_BWD_USE_AVX512 + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } +#undef LIBXSMM_DNN_FC_BWD_USE_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) + LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + return libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid ); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd3.xgemm.bmrs; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_upd_zerobeta = handle->gemm_upd3.xgemm.bmrs; + libxsmm_bsmmfunction bwd_tile_config_kernel = handle->bwd_config_kernel; + /*libxsmm_bsmmfunction upd_tile_config_kernel = handle->upd_config_kernel;*/ + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + return libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu(handle, kind, start_thread, tid); +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd3.xgemm.bmrs; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_upd_zerobeta = handle->gemm_upd3.xgemm.bmrs; + libxsmm_bsmmfunction bwd_tile_config_kernel = handle->bwd_config_kernel; + /*libxsmm_bsmmfunction upd_tile_config_kernel = handle->upd_config_kernel;*/ + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if all required tensors are bound */ + if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) { + if (handle->grad_input == 0 || handle->grad_output == 0 || + handle->reg_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) { + if (handle->reg_input == 0 || handle->grad_output == 0 || + handle->grad_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } else { + if (handle->grad_input == 0 || handle->grad_output == 0 || + handle->reg_input == 0 || handle->grad_filter == 0 || + handle->reg_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32( handle, kind, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__*/ + else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32( handle, kind, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock; + libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N; + libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock; + element_input_type alpha = (element_input_type)1; + element_input_type beta = (element_input_type)0; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL); + gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL); +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c" + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef float element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_smmfunction gemm_function; + libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock; + libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K; + libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N; + libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock; + float alpha = (element_input_type)1; + float beta = (element_input_type)0; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL); + gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL); +# define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32 +# define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32 +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32 +# undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32 + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + int l_emu_amx = 0; + const char *const l_env_emu_amx = getenv("EMULATE_AMX"); + if ( 0 == l_env_emu_amx ) { + } else { + l_emu_amx = atoi(l_env_emu_amx); + } + + /* check if all required tensors are bound */ + if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) { + if (handle->grad_input == 0 || handle->grad_output == 0 || + handle->reg_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) { + if (handle->reg_input == 0 || handle->grad_output == 0 || + handle->grad_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } else { + if (handle->grad_input == 0 || handle->grad_output == 0 || + handle->reg_input == 0 || handle->grad_filter == 0 || + handle->reg_filter == 0 || handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) != 0) && ( handle->grad_bias == 0 ) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) != 0) && ( handle->relumask == 0 ) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16( handle, kind, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + if ( l_emu_amx == 0 ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid); + } else { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu( handle, kind, start_thread, tid); + } + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) { + if ( l_emu_amx == 0 ) { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid); + } else { + status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu( handle, kind, start_thread, tid); + } + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + LIBXSMM_UNUSED( l_emu_amx ); + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( kind ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_backward_weight_update.h b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_backward_weight_update.h new file mode 100644 index 0000000000000000000000000000000000000000..ab59cd44e507ef481769eccb91dccb397c9bbd40 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_backward_weight_update.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FULLYCONNECTED_BACKWARD_WEIGHT_UPDATE_H +#define LIBXSMM_DNN_FULLYCONNECTED_BACKWARD_WEIGHT_UPDATE_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_FULLYCONNECTED_BACKWARD_WEIGHT_UPDATE_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_forward.c b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..52904ac7ca68611f58286b8aba2e5ec3424ad195 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_forward.c @@ -0,0 +1,649 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fullyconnected_forward.h" +#include "libxsmm_main.h" + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + element_input_type alpha = (element_input_type)1; + element_input_type beta = (element_input_type)0; + libxsmm_blasint lda = (libxsmm_blasint)handle->ofmblock; + libxsmm_blasint ldb = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint ldc = (libxsmm_blasint)handle->desc.K; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->desc.N, handle->desc.C, &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); +# include "template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c" + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef float element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + typedef libxsmm_smmfunction gemm_function; + libxsmm_blasint lda = (libxsmm_blasint)handle->ofmblock; + libxsmm_blasint ldb = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint ldc = (libxsmm_blasint)handle->desc.K; + float alpha = (element_input_type)1; + float beta = (element_input_type)0; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->desc.N, handle->desc.C, &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); +# define LIBXSMM_DNN_FULLYCONNECTED_FWD_BF16_F32 +# include "template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FULLYCONNECTED_FWD_BF16_F32 + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.smrs; + +#define LIBXSMM_DNN_FC_FWD_USE_AVX512 + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } +#undef LIBXSMM_DNN_FC_FWD_USE_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.bmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd3.xgemm.bmrs; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.bmrs; + libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd3.xgemm.bmrs; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + return libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu( handle, start_thread, tid ); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_zerobeta = handle->gemm_fwd3.xgemm.bmrs; + libxsmm_bsmmfunction tile_config_kernel = handle->fwd_config_kernel; +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if (handle->compressed_A == 1) { + libxsmm_bsmmfunction_reducebatch_strd_meltwfused batchreduce_kernel_decompress = handle->gemm_fwd9.xgemm.bsmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_decompress = handle->gemm_fwd11.xgemm.bmrs_meltwfused; + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd12.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd13.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd14.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd15.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd16.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else { + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) { + return libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu( handle, start_thread, tid ); +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs; + libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_zerobeta = handle->gemm_fwd3.xgemm.bmrs; + libxsmm_bsmmfunction tile_config_kernel = handle->fwd_config_kernel; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if (handle->compressed_A == 1) { + libxsmm_bsmmfunction_reducebatch_strd_meltwfused batchreduce_kernel_decompress = handle->gemm_fwd9.xgemm.bsmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_decompress = handle->gemm_fwd11.xgemm.bmrs_meltwfused; + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd12.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd13.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd14.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd15.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused; + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd16.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else { + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { + libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused; +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if all required tensors are bound */ + if (handle->reg_input == 0 || handle->reg_output == 0 || + handle->reg_filter == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fullyconnected_st_fwd_custom_f32_f32( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE ) { + status = libxsmm_dnn_fullyconnected_st_fwd_custom_bf16_f32( handle, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + typedef libxsmm_smmfunction gemm_function; + libxsmm_blasint lda = (libxsmm_blasint)handle->ofmblock; + libxsmm_blasint ldb = (libxsmm_blasint)handle->desc.C; + libxsmm_blasint ldc = (libxsmm_blasint)handle->desc.K; + element_input_type beta = (element_input_type)0; + element_input_type alpha = (element_input_type)1; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->desc.N, handle->desc.C, &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL); +# include "template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c" + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + int l_emu_amx = 0; + const char *const l_env_emu_amx = getenv("EMULATE_AMX"); + if ( 0 == l_env_emu_amx ) { + } else { + l_emu_amx = atoi(l_env_emu_amx); + } + + /* check if all required tensors are bound */ + if (handle->reg_input == 0 || handle->reg_output == 0 || + handle->reg_filter == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) != 0) && ( handle->reg_bias == 0 ) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) != 0) && ( handle->relumask == 0 ) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) { + if ( l_emu_amx == 0 ) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid); + } else { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu( handle, start_thread, tid); + } + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) { + if ( l_emu_amx == 0 ) { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid); + } else { + status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu( handle, start_thread, tid); + } + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + LIBXSMM_UNUSED( l_emu_amx ); + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd.xgemm.smrs; + libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.smrs; + + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_NONE +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_RELU +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) { +#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS +#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c" +#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID +#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + } else { + status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION; + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_nhwc(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_forward.h b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..949bc955b0fc394c7720a983af318c41599dea71 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fullyconnected_forward.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FULLYCONNECTED_FORWARD_H +#define LIBXSMM_DNN_FULLYCONNECTED_FORWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_nhwc(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_FULLYCONNECTED_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm.c b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm.c new file mode 100644 index 0000000000000000000000000000000000000000..6d91c8d4e78afdf27afa679154de1e54c1376286 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm.c @@ -0,0 +1,638 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fusedbatchnorm_backward.h" +#include "libxsmm_dnn_fusedbatchnorm_forward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API libxsmm_dnn_fusedbatchnorm* libxsmm_dnn_create_fusedbatchnorm(libxsmm_dnn_fusedbatchnorm_desc fusedbatchnorm_desc, libxsmm_dnn_err_t* status) { + libxsmm_dnn_fusedbatchnorm* handle = 0; + int lpb; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( fusedbatchnorm_desc.partN > fusedbatchnorm_desc.fullN ) { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + return handle; + } else if ( (fusedbatchnorm_desc.partN != fusedbatchnorm_desc.fullN) && ((fusedbatchnorm_desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) == 0 ) && ((fusedbatchnorm_desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) == 0 ) ) { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + return handle; + } else { + } + + if ( ((fusedbatchnorm_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fusedbatchnorm_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) || + ((fusedbatchnorm_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (fusedbatchnorm_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_fusedbatchnorm*)calloc(1, sizeof(libxsmm_dnn_fusedbatchnorm)); + + if (0 != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* let's make the description persistent */ + handle->desc = fusedbatchnorm_desc; + /* we need to compute the memory layout given the */ + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C, + &(handle->ifmblock), &(handle->ofmblock), &lpb, + handle->desc.datatype_in, handle->desc.datatype_out ); + /* compute the outer blocks */ + handle->blocksifm = handle->desc.C / handle->ifmblock; + handle->blocksofm = handle->desc.C / handle->ofmblock; + /* create barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + /* calculate scratch size for batchstats */ + handle->scratch_size = (sizeof(float) * 2 * handle->desc.C * handle->desc.partN); + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + } else { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fusedbatchnorm(const libxsmm_dnn_fusedbatchnorm* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_fusedbatchnorm*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fusedbatchnorm_create_tensor_datalayout(const libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + layout->format = handle->desc.buffer_format; + + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.partN; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.partN; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.partN; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.partN; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->desc.partN; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->desc.partN; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) || (type == LIBXSMM_DNN_CHANNEL_BETA) || + (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_CHANNEL_GAMMA) || + (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL) || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV) || (type == LIBXSMM_DNN_CHANNEL_VARIANCE) ) { + layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR; + + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->desc.datatype_stats; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->blocksifm; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->desc.datatype_stats; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 1; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->desc.C; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) { + layout->tensor_type = LIBXSMM_DNN_RELU_MASK; + + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + layout->datatype = LIBXSMM_DNN_DATATYPE_I8; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.partN; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + layout->datatype = LIBXSMM_DNN_DATATYPE_I8; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 6; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock*handle->blocksofm; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->desc.partN; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + +LIBXSMM_API size_t libxsmm_dnn_fusedbatchnorm_get_scratch_size(const libxsmm_dnn_fusedbatchnorm* handle, libxsmm_dnn_err_t* status) { + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */ + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_bind_scratch(libxsmm_dnn_fusedbatchnorm* handle, const void* scratch) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + /* align the internal scratch buffer if needed */ + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_release_scratch(libxsmm_dnn_fusedbatchnorm* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_bind_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) && + (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) && + (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fusedbatchnorm_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD ) { + handle->reg_add = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD ) { + handle->grad_add = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA ) { + handle->reg_beta = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA ) { + handle->grad_beta = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA ) { + handle->reg_gamma = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { + handle->grad_gamma = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL ) { + handle->expvalue = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV ) { + handle->rcpstddev = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE ) { + handle->variance = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + handle->relumask = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fusedbatchnorm_get_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) && + (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) && + (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + return_tensor = handle->reg_input; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + return_tensor = handle->grad_input; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + return_tensor = handle->reg_output; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + return_tensor = handle->grad_output; + } else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD ) { + return_tensor = handle->reg_add; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD ) { + return_tensor = handle->grad_add; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA ) { + return_tensor = handle->reg_beta; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA ) { + return_tensor = handle->grad_beta; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA ) { + return_tensor = handle->reg_gamma; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { + return_tensor = handle->grad_gamma; + } else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL ) { + return_tensor = handle->expvalue; + } else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV ) { + return_tensor = handle->rcpstddev; + } else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE ) { + return_tensor = handle->variance; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + return_tensor = handle->relumask; + } else { + /* cannot happen */ + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_release_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) && + (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) && + (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD ) { + handle->reg_add = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD ) { + handle->grad_add = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA ) { + handle->reg_beta = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA ) { + handle->grad_beta = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA ) { + handle->reg_gamma = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { + handle->grad_gamma = 0; + } else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL ) { + handle->expvalue = 0; + } else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV ) { + handle->rcpstddev = 0; + } else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE ) { + handle->variance = 0; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + handle->relumask = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_execute_st(libxsmm_dnn_fusedbatchnorm* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + switch (handle->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + switch (handle->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handles && num_handles > 0) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + switch (handles[0]->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom( handles, num_handles, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + switch (handles[0]->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom( handles, num_handles, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_backward.c b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_backward.c new file mode 100644 index 0000000000000000000000000000000000000000..2d632f42f270521871ae9f9c13f5d3c2a531dde2 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_backward.c @@ -0,0 +1,604 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fusedbatchnorm_backward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if all required tensors are bound */ + if ( handle->reg_input == 0 || handle->reg_gamma == 0 || + handle->grad_input == 0 || handle->grad_output == 0 || + handle->grad_beta == 0 || handle->grad_gamma == 0 || + handle->expvalue == 0 || handle->rcpstddev == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0 ) { + if ( handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) > 0 ) { + if ( handle->grad_add == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) > 0 ) { + if ( handle->reg_output == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) > 0 ) { + if ( handle->relumask == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 16) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c16( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 32) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c32( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 64) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c64( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) { +# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + int l_count; + + /* check if all required tensors are bound */ + for ( l_count = 0; l_count < num_handles; ++l_count ) { + if ( handles[l_count]->grad_beta == 0 || handles[l_count]->grad_gamma == 0 || handles[l_count]->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + +#if 0 + /* check if we are on an AVX512 platform */ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) { + status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom_avx512( handles, num_handles, start_thread, tid ); + } else +#endif + { + const int nImg = handles[0]->desc.partN; + const int nBlocksFm = handles[0]->blocksifm; + const int nFmBlock = handles[0]->ifmblock; + /* computing first logical thread */ + const int ltid = tid - start_thread; + /* number of tasks that could be run in parallel */ + const int work2 = nBlocksFm; + /* compute chunk size */ + const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; + const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + int v = 0, fm; + + LIBXSMM_VLA_DECL(2, float, dgamma0, (float*)handles[0]->grad_gamma->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, dbeta0, (float*)handles[0]->grad_beta->data, nFmBlock); + LIBXSMM_VLA_DECL(3, float, dgamma_img0, (float*)handles[0]->scratch, nImg, nFmBlock); + LIBXSMM_VLA_DECL(3, float, dbeta_img0, ((float*)handles[0]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); + + /* lazy barrier init */ + libxsmm_barrier_init(handles[0]->barrier, ltid); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock); + float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock); + float* dgamma_img0_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img0, fm, 0, 0, nImg, nFmBlock); + float* dbeta_img0_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img0, fm, 0, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + dgamma0_ptr[v] = dgamma_img0_ptr[v]; + dbeta0_ptr[v] = dbeta_img0_ptr[v]; + } + } + + /* now we need to reduce the dgamma and dbeta */ + for ( l_count = 1; l_count < num_handles; ++l_count ) { + LIBXSMM_VLA_DECL(3, float, dgamma_imgr, (float*)handles[l_count]->scratch, nImg, nFmBlock); + LIBXSMM_VLA_DECL(3, float, dbeta_imgr, ((float*)handles[l_count]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock); + float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock); + float* dgamma_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_imgr, fm, 0, 0, nImg, nFmBlock); + float* dbeta_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_imgr, fm, 0, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + dgamma0_ptr[v] += dgamma_imgr_ptr[v]; + dbeta0_ptr[v] += dbeta_imgr_ptr[v]; + } + } + } + + for ( l_count = 1; l_count < num_handles; ++l_count ) { + LIBXSMM_VLA_DECL(2, float, dgammar, (float*)handles[l_count]->grad_gamma->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, dbetar, (float*)handles[l_count]->grad_beta->data, nFmBlock); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock); + float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock); + float* dgammar_ptr = &LIBXSMM_VLA_ACCESS(2, dgammar, fm, 0, nFmBlock); + float* dbetar_ptr = &LIBXSMM_VLA_ACCESS(2, dbetar, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + dgammar_ptr[v] = dgamma0_ptr[v]; + dbetar_ptr[v] = dbeta0_ptr[v]; + } + } + } + + libxsmm_barrier_wait(handles[0]->barrier, ltid); + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_backward.h b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..a09c3421785329904a19b79429ade168bb98434b --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_backward.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FUSEDBATCHNORM_BACKWARD_H +#define LIBXSMM_DNN_FUSEDBATCHNORM_BACKWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_FUSEDBATCHNORM_BACKWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_forward.c b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..fd3bf92e1921020b123c89471ffeec0fe40e5642 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_forward.c @@ -0,0 +1,618 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fusedbatchnorm_forward.h" +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if all required tensors are bound */ + if ( handle->reg_input == 0 || handle->reg_output == 0 || + handle->reg_beta == 0 || handle->reg_gamma == 0 || + handle->expvalue == 0 || handle->rcpstddev == 0 || handle->variance == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0 ) { + if ( handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) > 0 ) { + if ( handle->reg_add == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) > 0 ) { + if ( handle->relumask == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 16) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 32) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 64) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDBN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) || + (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) { +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + int l_count; + + /* check if all required tensors are bound */ + for ( l_count = 0; l_count < num_handles; ++l_count ) { + if ( handles[l_count]->expvalue == 0 || handles[l_count]->rcpstddev == 0 || handles[l_count]->variance == 0 || handles[l_count]->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + +#if 0 + /* check if we are on an AVX512 platform */ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) { + status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom_avx512( handles, num_handles, start_thread, tid ); + } else +#endif + { + const int nImg = handles[0]->desc.partN; + const int nBlocksFm = handles[0]->blocksifm; + const int nFmBlock = handles[0]->ifmblock; + /* computing first logical thread */ + const int ltid = tid - start_thread; + /* number of tasks that could be run in parallel */ + const int work2 = nBlocksFm; + /* compute chunk size */ + const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; + const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + int v = 0, fm; + const float sqrt_eps = 1e-7f; + const float nhw = (float)(handles[0]->desc.fullN * handles[0]->desc.H * handles[0]->desc.W); + const float recp_nhw = 1.0f/nhw; + + LIBXSMM_VLA_DECL(2, float, bmean0, (float*)handles[0]->expvalue->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, brstd0, (float*)handles[0]->rcpstddev->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, variance0, (float*)handles[0]->variance->data, nFmBlock); + LIBXSMM_VLA_DECL(3, float, sum_img0, (float*)handles[0]->scratch, nImg, nFmBlock); + LIBXSMM_VLA_DECL(3, float, sumsq_img0, ((float*)handles[0]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); + + /* lazy barrier init */ + libxsmm_barrier_init(handles[0]->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( l_count = 1; l_count < num_handles; ++l_count ) { + LIBXSMM_VLA_DECL(3, float, sum_imgr, (float*)handles[l_count]->scratch, nImg, nFmBlock); + LIBXSMM_VLA_DECL(3, float, sumsq_imgr, ((float*)handles[l_count]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* sum_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img0, fm, 0, 0, nImg, nFmBlock); + float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock); + float* sum_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sum_imgr, fm, 0, 0, nImg, nFmBlock); + float* sumsq_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_imgr, fm, 0, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + sum_img0_ptr[v] += sum_imgr_ptr[v]; + sumsq_img0_ptr[v] += sumsq_imgr_ptr[v]; + } + } + } + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* bmean0_ptr = &LIBXSMM_VLA_ACCESS(2, bmean0, fm, 0, nFmBlock); + float* brstd0_ptr = &LIBXSMM_VLA_ACCESS(2, brstd0, fm, 0, nFmBlock); + float* tvar0_ptr = &LIBXSMM_VLA_ACCESS(2, variance0, fm, 0, nFmBlock); + float* sum_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img0, fm, 0, 0, nImg, nFmBlock); + float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + const float tbmean = (recp_nhw * sum_img0_ptr[v]); + const float tbmeansq = tbmean * tbmean; + const float tsqbmean = recp_nhw * sumsq_img0_ptr[v]; + const float tvar = tsqbmean - tbmeansq; + const float tbrstd = (float)(1.0/sqrt((double)tvar + sqrt_eps)); + bmean0_ptr[v] = tbmean; + brstd0_ptr[v] = tbrstd; + tvar0_ptr[v] = tvar; + } + } + + for ( l_count = 1; l_count < num_handles; ++l_count ) { + LIBXSMM_VLA_DECL(2, float, bmeanr, (float*)handles[l_count]->expvalue->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, brstdr, (float*)handles[l_count]->rcpstddev->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, variancer, (float*)handles[l_count]->variance->data, nFmBlock); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* bmean0_ptr = &LIBXSMM_VLA_ACCESS(2, bmean0, fm, 0, nFmBlock); + float* brstd0_ptr = &LIBXSMM_VLA_ACCESS(2, brstd0, fm, 0, nFmBlock); + float* tvar0_ptr = &LIBXSMM_VLA_ACCESS(2, variance0, fm, 0, nFmBlock); + float* bmeanr_ptr = &LIBXSMM_VLA_ACCESS(2, bmeanr, fm, 0, nFmBlock); + float* brstdr_ptr = &LIBXSMM_VLA_ACCESS(2, brstdr, fm, 0, nFmBlock); + float* tvarr_ptr = &LIBXSMM_VLA_ACCESS(2, variancer, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + bmeanr_ptr[v] = bmean0_ptr[v]; + brstdr_ptr[v] = brstd0_ptr[v]; + tvarr_ptr[v] = tvar0_ptr[v]; + } + } + } + + libxsmm_barrier_wait(handles[0]->barrier, ltid); + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_forward.h b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..dfd76f6670a39c17c07ac24749f5e757972ff627 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedbatchnorm_forward.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FUSEDBATCHNORM_FORWARD_H +#define LIBXSMM_DNN_FUSEDBATCHNORM_FORWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_FUSEDBATCHNORM_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm.c b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm.c new file mode 100644 index 0000000000000000000000000000000000000000..97796014105dfd3f007408f5aebf176a7bee09d8 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm.c @@ -0,0 +1,648 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fusedgroupnorm_backward.h" +#include "libxsmm_dnn_fusedgroupnorm_forward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API libxsmm_dnn_fusedgroupnorm* libxsmm_dnn_create_fusedgroupnorm(libxsmm_dnn_fusedgroupnorm_desc fusedgroupnorm_desc, libxsmm_dnn_err_t* status) { + libxsmm_dnn_fusedgroupnorm* handle = 0; + int lpb; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( ((fusedgroupnorm_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fusedgroupnorm_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) || + ((fusedgroupnorm_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (fusedgroupnorm_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_fusedgroupnorm*)calloc(1, sizeof(libxsmm_dnn_fusedgroupnorm)); + + if (0 != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* let's make the description persistent */ + handle->desc = fusedgroupnorm_desc; + /* we need to compute the memory layout given the */ + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C, + &(handle->ifmblock), &(handle->ofmblock), &lpb, + handle->desc.datatype_in, handle->desc.datatype_out ); + /* compute the outer blocks */ + handle->blocksifm = handle->desc.C / handle->ifmblock; + handle->blocksofm = handle->desc.C / handle->ofmblock; + /* create barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + /* calculate scratch size for batchstats */ + handle->scratch_size = (sizeof(float) * 2 * ((handle->desc.C * handle->desc.N) + (handle->desc.G * handle->desc.N))); + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + } else { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fusedgroupnorm(const libxsmm_dnn_fusedgroupnorm* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_fusedgroupnorm*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fusedgroupnorm_create_tensor_datalayout(const libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + layout->format = handle->desc.buffer_format; + + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) || (type == LIBXSMM_DNN_CHANNEL_BETA) || + (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_CHANNEL_GAMMA) ) { + layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR; + + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->desc.datatype_stats; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->blocksifm; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->desc.datatype_stats; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 1; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = handle->desc.C; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL) || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV) || (type == LIBXSMM_DNN_CHANNEL_VARIANCE) ) { + layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR; + + if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) || ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) ) { + if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->desc.datatype_stats; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 2; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_G; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->desc.G; + layout->dim_size[1] = handle->desc.N; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) { + layout->tensor_type = LIBXSMM_DNN_RELU_MASK; + + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + layout->datatype = LIBXSMM_DNN_DATATYPE_I8; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + layout->datatype = LIBXSMM_DNN_DATATYPE_I8; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 6; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->ofmblock*handle->blocksofm; + layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->desc.N; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + +LIBXSMM_API size_t libxsmm_dnn_fusedgroupnorm_get_scratch_size(const libxsmm_dnn_fusedgroupnorm* handle, libxsmm_dnn_err_t* status) { + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */ + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_bind_scratch(libxsmm_dnn_fusedgroupnorm* handle, const void* scratch) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + /* align the internal scratch buffer if needed */ + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_release_scratch(libxsmm_dnn_fusedgroupnorm* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_bind_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) && + (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) && + (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fusedgroupnorm_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD ) { + handle->reg_add = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD ) { + handle->grad_add = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA ) { + handle->reg_beta = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA ) { + handle->grad_beta = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA ) { + handle->reg_gamma = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { + handle->grad_gamma = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL ) { + handle->expvalue = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV ) { + handle->rcpstddev = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE ) { + handle->variance = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + handle->relumask = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fusedgroupnorm_get_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) && + (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) && + (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + return_tensor = handle->reg_input; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + return_tensor = handle->grad_input; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + return_tensor = handle->reg_output; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + return_tensor = handle->grad_output; + } else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD ) { + return_tensor = handle->reg_add; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD ) { + return_tensor = handle->grad_add; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA ) { + return_tensor = handle->reg_beta; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA ) { + return_tensor = handle->grad_beta; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA ) { + return_tensor = handle->reg_gamma; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { + return_tensor = handle->grad_gamma; + } else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL ) { + return_tensor = handle->expvalue; + } else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV ) { + return_tensor = handle->rcpstddev; + } else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE ) { + return_tensor = handle->variance; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + return_tensor = handle->relumask; + } else { + /* cannot happen */ + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_release_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) && + (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) && + (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) && + (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD ) { + handle->reg_add = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD ) { + handle->grad_add = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA ) { + handle->reg_beta = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA ) { + handle->grad_beta = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA ) { + handle->reg_gamma = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { + handle->grad_gamma = 0; + } else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL ) { + handle->expvalue = 0; + } else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV ) { + handle->rcpstddev = 0; + } else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE ) { + handle->variance = 0; + } else if ( type == LIBXSMM_DNN_RELU_MASK ) { + handle->relumask = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_execute_st(libxsmm_dnn_fusedgroupnorm* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + switch (handle->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + switch (handle->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handles && num_handles > 0) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + switch (handles[0]->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom( handles, num_handles, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_backward.c b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_backward.c new file mode 100644 index 0000000000000000000000000000000000000000..1cdc7142aa54e8b13a97191a3786ebbf2983ffb2 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_backward.c @@ -0,0 +1,581 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fusedgroupnorm_backward.h" +#include "libxsmm_main.h" + +#if 0 +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if all required tensors are bound */ + if ( handle->reg_input == 0 || handle->reg_gamma == 0 || + handle->grad_input == 0 || handle->grad_output == 0 || + handle->grad_beta == 0 || handle->grad_gamma == 0 || + handle->expvalue == 0 || handle->rcpstddev == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_GN) > 0 ) { + if ( handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { + if ( handle->grad_add == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { + if ( handle->reg_output == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { + if ( handle->relumask == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + + /* check if we are on an AVX512 platform */ +#if 0 +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 16) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c16( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 32) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c32( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 64) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c64( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_BWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + int l_count; + + /* check if all required tensors are bound */ + for ( l_count = 0; l_count < num_handles; ++l_count ) { + if ( handles[l_count]->grad_beta == 0 || handles[l_count]->grad_gamma == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + +#if 0 + /* check if we are on an AVX512 platform */ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) { + status = libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom_avx512( handles, num_handles, start_thread, tid ); + } else +#endif + { + const int nBlocksFm = handles[0]->blocksifm; + const int nFmBlock = handles[0]->ifmblock; + /* computing first logical thread */ + const int ltid = tid - start_thread; + /* number of tasks that could be run in parallel */ + const int work2 = nBlocksFm; + /* compute chunk size */ + const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; + const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + int v = 0, fm; + + LIBXSMM_VLA_DECL(2, float, dgamma0, (float*)handles[0]->grad_gamma->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, dbeta0, (float*)handles[0]->grad_beta->data, nFmBlock); + + /* lazy barrier init */ + libxsmm_barrier_init(handles[0]->barrier, ltid); + + /* now we need to reduce the dgamma and dbeta */ + for ( l_count = 1; l_count < num_handles; ++l_count ) { + LIBXSMM_VLA_DECL(2, float, dgammar, (float*)handles[l_count]->grad_gamma->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, dbetar, (float*)handles[l_count]->grad_beta->data, nFmBlock); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock); + float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock); + float* dgammar_ptr = &LIBXSMM_VLA_ACCESS(2, dgammar, fm, 0, nFmBlock); + float* dbetar_ptr = &LIBXSMM_VLA_ACCESS(2, dbetar, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + dgamma0_ptr[v] += dgammar_ptr[v]; + dbeta0_ptr[v] += dbetar_ptr[v]; + } + } + } + + for ( l_count = 1; l_count < num_handles; ++l_count ) { + LIBXSMM_VLA_DECL(2, float, dgammar, (float*)handles[l_count]->grad_gamma->data, nFmBlock); + LIBXSMM_VLA_DECL(2, float, dbetar, (float*)handles[l_count]->grad_beta->data, nFmBlock); + + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock); + float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock); + float* dgammar_ptr = &LIBXSMM_VLA_ACCESS(2, dgammar, fm, 0, nFmBlock); + float* dbetar_ptr = &LIBXSMM_VLA_ACCESS(2, dbetar, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + dgammar_ptr[v] = dgamma0_ptr[v]; + dbetar_ptr[v] = dbeta0_ptr[v]; + } + } + } + + libxsmm_barrier_wait(handles[0]->barrier, ltid); + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_backward.h b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..4ef94633993169b7032bea44fbf9c28373edcc05 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_backward.h @@ -0,0 +1,22 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FUSEDGROUPNORM_BACKWARD_H +#define LIBXSMM_DNN_FUSEDGROUPNORM_BACKWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_FUSEDGROUPNORM_BACKWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_forward.c b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..df7d3a9b70221b3b7b21eb445559d75a3c505cc8 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_forward.c @@ -0,0 +1,500 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_fusedgroupnorm_forward.h" +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if 0 +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if all required tensors are bound */ + if ( handle->reg_input == 0 || handle->reg_output == 0 || + handle->reg_beta == 0 || handle->reg_gamma == 0 || + handle->expvalue == 0 || handle->rcpstddev == 0 || handle->variance == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_GN) > 0 ) { + if ( handle->scratch == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) { + if ( handle->reg_add == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) { + if ( handle->relumask == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + } + + /* check if we are on an AVX512 platform */ +#if 0 +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 16) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c16( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 32) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c32( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 64) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c64( handle, start_thread, tid ); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_stats_type; + + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef float element_stats_type; + +# define LIBXSMM_DNN_FUSEDGN_FWD_BF16 + if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER; + } else { + if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) { +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU + } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) { +# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK +# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK + } else { + status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION; + } + } +# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_forward.h b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..41d11bfad4c156a74e81ca4eed8b4c0a949d2298 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_fusedgroupnorm_forward.h @@ -0,0 +1,20 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_FUSEDGROUPNORM_FORWARD_H +#define LIBXSMM_DNN_FUSEDGROUPNORM_FORWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_FUSEDGROUPNORM_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_optimizer.c b/third_party/libxsmm/src/libxsmm_dnn_optimizer.c new file mode 100644 index 0000000000000000000000000000000000000000..8d32287977ceb158831565ca9deafe79c4bcfb20 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_optimizer.c @@ -0,0 +1,345 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_optimizer_sgd.h" +#include "libxsmm_main.h" + + +LIBXSMM_API libxsmm_dnn_optimizer* libxsmm_dnn_create_optimizer(libxsmm_dnn_optimizer_desc optimizer_desc, libxsmm_dnn_err_t* status) { + libxsmm_dnn_optimizer* handle = 0; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( (optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_F32) || (optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) ) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_optimizer*)calloc(1, sizeof(libxsmm_dnn_optimizer)); + + if (0 != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* let's make the description persistent */ + handle->desc = optimizer_desc; + + if ( (handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + /* we need to compute the memory layout given the */ + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K, + &(handle->bc), &(handle->bk), &(handle->fm_lp_block), + handle->desc.datatype, handle->desc.datatype ); + /* compute the outer blocks */ + handle->Bc = handle->desc.C / handle->bc; + handle->Bk = handle->desc.K / handle->bk; + } else if ( (handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0 ) { + if ( optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + handle->fm_lp_block = 1; + } else if ( optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + handle->fm_lp_block = 2; + } else { + } + handle->bc = handle->desc.bc; + handle->bk = handle->desc.bk; + handle->Bc = handle->desc.C / handle->bc; + handle->Bk = handle->desc.K / handle->bk; + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + free( handle ); + handle = 0; + return handle; + } + /* create barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + /* calculate scratch size for local optimizer copies of one feature map block per thread */ + handle->scratch_size = 1; + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + } else { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_optimizer(const libxsmm_dnn_optimizer* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_optimizer*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_optimizer_create_tensor_datalayout(const libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + layout->format = handle->desc.filter_format; + + if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_MASTER_FILTER) ) { + if ( ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) || ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) ) { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + layout->datatype = handle->desc.datatype; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->bk; + layout->dim_size[1] = handle->bc; + layout->dim_size[2] = handle->Bc; + layout->dim_size[3] = handle->Bk; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + layout->datatype = handle->desc.datatype; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = handle->fm_lp_block; + layout->dim_size[1] = handle->bk; + layout->dim_size[2] = handle->bc/handle->fm_lp_block; + layout->dim_size[3] = handle->Bc; + layout->dim_size[4] = handle->Bk; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + + +LIBXSMM_API size_t libxsmm_dnn_optimizer_get_scratch_size(const libxsmm_dnn_optimizer* handle, libxsmm_dnn_err_t* status) { + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */ + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API void* libxsmm_dnn_optimizer_get_scratch_ptr(const libxsmm_dnn_optimizer* handle, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + return handle->scratch; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return 0; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_bind_scratch(libxsmm_dnn_optimizer* handle, const void* scratch) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + /* align the internal scratch buffer if needed */ + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_release_scratch(libxsmm_dnn_optimizer* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_bind_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && (type != LIBXSMM_DNN_MASTER_FILTER) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_optimizer_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + handle->reg_filter = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + handle->grad_filter = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_MASTER_FILTER ) { + handle->master_filter = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_optimizer_get_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && (type != LIBXSMM_DNN_MASTER_FILTER) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + return_tensor = handle->reg_filter; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + return_tensor = handle->grad_filter; + } else if ( type == LIBXSMM_DNN_MASTER_FILTER ) { + return_tensor = handle->master_filter; + } else { + /* cannot happen */ + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_release_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && (type != LIBXSMM_DNN_MASTER_FILTER) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_FILTER ) { + handle->reg_filter = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) { + handle->grad_filter = 0; + } else if ( type == LIBXSMM_DNN_MASTER_FILTER ) { + handle->master_filter = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_execute_st(libxsmm_dnn_optimizer* handle, /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + if (handle->desc.opt_type == LIBXSMM_DNN_OPTIMIZER_SGD) { + libxsmm_dnn_optimizer_sgd_st( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_optimizer_sgd.c b/third_party/libxsmm/src/libxsmm_dnn_optimizer_sgd.c new file mode 100644 index 0000000000000000000000000000000000000000..b1532c24a28e2aa712ce989b7ec514bdbeeb86a5 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_optimizer_sgd.c @@ -0,0 +1,103 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_optimizer_sgd.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_f32_f32(libxsmm_dnn_optimizer* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_bf16_bf16(libxsmm_dnn_optimizer* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_f32_f32(libxsmm_dnn_optimizer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_filter_type; + +# define LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512 +# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c" +# undef LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_bf16_bf16(libxsmm_dnn_optimizer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_filter_type; + typedef float element_master_type; + +# define LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512 +# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c" +# undef LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st(libxsmm_dnn_optimizer* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have filter, grad_filter */ + if ( handle->reg_filter == 0 || handle->grad_filter == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + if ( (handle->master_filter == 0) && (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_optimizer_sgd_st_f32_f32( handle, start_thread, tid); + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_optimizer_sgd_st_bf16_bf16( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_filter_type; + +# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c" + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_filter_type; + typedef float element_master_type; + +# define LIBXSMM_DNN_OPTIMIZER_SGD_BF16 +# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c" +# undef LIBXSMM_DNN_OPTIMIZER_SGD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_optimizer_sgd.h b/third_party/libxsmm/src/libxsmm_dnn_optimizer_sgd.h new file mode 100644 index 0000000000000000000000000000000000000000..7bc64fc845f25199349595086a8c759cb834b198 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_optimizer_sgd.h @@ -0,0 +1,18 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_OPTIMIZER_SGD_H +#define LIBXSMM_DNN_OPTIMIZER_SGD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st(libxsmm_dnn_optimizer* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_OPTIMIZER_SGD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_pooling.c b/third_party/libxsmm/src/libxsmm_dnn_pooling.c new file mode 100644 index 0000000000000000000000000000000000000000..764663d4467f7cc953af2328090f17bcfa600b65 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_pooling.c @@ -0,0 +1,451 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_pooling_backward.h" +#include "libxsmm_dnn_pooling_forward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status) { + libxsmm_dnn_pooling* handle = 0; + int lpb; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) || + ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_pooling*)calloc(1, sizeof(libxsmm_dnn_pooling)); + + if (0 != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* let's make the description persistent */ + handle->desc = pooling_desc; + /* we need to compute the memory layout given the */ + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C, + &(handle->ifmblock), &(handle->ofmblock), &lpb, + handle->desc.datatype_in, handle->desc.datatype_out ); + /* compute the outer blocks */ + handle->blocksifm = handle->desc.C / handle->ifmblock; + handle->blocksofm = handle->desc.C / handle->ofmblock; + /* setting ofh and ofw */ + handle->ofh = (handle->desc.H + 2 * handle->desc.pad_h - handle->desc.R) / handle->desc.u + 1; + handle->ofw = (handle->desc.W + 2 * handle->desc.pad_w - handle->desc.S) / handle->desc.v + 1; + /* create barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + /* calculate scratch size for local pooling copies of one feature map block per thread */ + handle->scratch_size = (sizeof(float) * ( (size_t)handle->desc.H + (size_t)LIBXSMM_MAX(handle->desc.pad_h_in, handle->desc.pad_h_out)*2 ) + * ( (size_t)handle->desc.W + (size_t)LIBXSMM_MAX(handle->desc.pad_w_in, handle->desc.pad_w_out)*2 ) + * LIBXSMM_MAX( handle->ofmblock, handle->ifmblock ) + * handle->desc.threads ); + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + } else { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_pooling*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + layout->format = handle->desc.buffer_format; + + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) || + (type == LIBXSMM_DNN_POOLING_MASK) ) { + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) { + if ( type == LIBXSMM_DNN_POOLING_MASK ) { + layout->datatype = handle->desc.datatype_mask; + } else { + layout->datatype = LIBXSMM_DNN_DATATYPE_F32; + } + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofw; + layout->dim_size[2] = handle->ofh; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + if ( type == LIBXSMM_DNN_POOLING_MASK ) { + layout->datatype = handle->desc.datatype_mask; + } else { + layout->datatype = LIBXSMM_DNN_DATATYPE_BF16; + } + + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 5; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->ifmblock; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->blocksifm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) { + layout->dim_size[0] = handle->ofmblock; + layout->dim_size[1] = handle->ofw; + layout->dim_size[2] = handle->ofh; + layout->dim_size[3] = handle->blocksofm; + layout->dim_size[4] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || + ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + if ( type == LIBXSMM_DNN_POOLING_MASK ) { + layout->datatype = handle->desc.datatype_mask; + } else { + layout->datatype = handle->desc.datatype_in; + } + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in); + layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in); + layout->dim_size[3] = handle->desc.N; + } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + layout->dim_size[0] = handle->desc.C; + layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out); + layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out); + layout->dim_size[3] = handle->desc.N; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + +LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status) { + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */ + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling* handle, const void* scratch) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + /* align the internal scratch buffer if needed */ + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_POOLING_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_pooling_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_POOLING_MASK ) { + handle->mask = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_POOLING_MASK) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + return_tensor = handle->reg_input; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + return_tensor = handle->grad_input; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + return_tensor = handle->reg_output; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + return_tensor = handle->grad_output; + } else if ( type == LIBXSMM_DNN_POOLING_MASK ) { + return_tensor = handle->mask; + } else { + /* cannot happen */ + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) && + (type != LIBXSMM_DNN_POOLING_MASK) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) { + handle->grad_output = 0; + } else if ( type == LIBXSMM_DNN_POOLING_MASK ) { + handle->mask = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + switch (handle->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_pooling_st_fwd_custom( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + switch (handle->desc.buffer_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: { + status = libxsmm_dnn_pooling_st_bwd_custom( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_pooling_backward.c b/third_party/libxsmm/src/libxsmm_dnn_pooling_backward.c new file mode 100644 index 0000000000000000000000000000000000000000..6cffe9c8fa94f726e66b95b5debfdb6cb9d136d1 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_pooling_backward.c @@ -0,0 +1,301 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_pooling_backward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_BWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_BWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_BWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_BWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and mask */ + if ( handle->grad_input == 0 || handle->grad_output == 0 || + ( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 16) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 32) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 64) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_BWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_BWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_BWD_AVG +# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_BWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_BWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_pooling_backward.h b/third_party/libxsmm/src/libxsmm_dnn_pooling_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..ce08683db2de38f4985980c029b6e93471ddf500 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_pooling_backward.h @@ -0,0 +1,20 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_POOLING_BACKWARD_H +#define LIBXSMM_DNN_POOLING_BACKWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_POOLING_BACKWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_pooling_forward.c b/third_party/libxsmm/src/libxsmm_dnn_pooling_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..dc2a16d99fa66600e2467d5b598d2706f67ff944 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_pooling_forward.c @@ -0,0 +1,301 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_pooling_forward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_FWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_FWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_FWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_FWD_BF16 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and mask */ + if ( handle->reg_input == 0 || handle->reg_output == 0 || + ( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 16) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 32) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && + (handle->ofmblock == 64) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64( handle, start_thread, tid); + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + LIBXSMM_ASSERT(NULL != handle->mask); + status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } + } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + +# define LIBXSMM_DNN_POOLING_FWD_BF16 + if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) { +# define LIBXSMM_DNN_POOLING_FWD_MAX + typedef int element_mask_type; +# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_MAX + } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) { +# define LIBXSMM_DNN_POOLING_FWD_AVG +# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c" +# undef LIBXSMM_DNN_POOLING_FWD_AVG + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING; + } +# undef LIBXSMM_DNN_POOLING_FWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + LIBXSMM_UNUSED( handle ); + LIBXSMM_UNUSED( start_thread ); + LIBXSMM_UNUSED( tid ); + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_pooling_forward.h b/third_party/libxsmm/src/libxsmm_dnn_pooling_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..e7eb4322bfa0455d6a6831c0da2147dd80b3bde3 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_pooling_forward.h @@ -0,0 +1,20 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_POOLING_FORWARD_H +#define LIBXSMM_DNN_POOLING_FORWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid); + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_POOLING_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_rnncell.c b/third_party/libxsmm/src/libxsmm_dnn_rnncell.c new file mode 100644 index 0000000000000000000000000000000000000000..ad3fa5b6af2ee06c93072a907b7639d109758745 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_rnncell.c @@ -0,0 +1,2357 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_rnncell_forward.h" +#include "libxsmm_dnn_rnncell_backward_weight_update.h" +#include "libxsmm_dnn_elementwise.h" +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status) +{ + libxsmm_dnn_rnncell* handle = 0; + + /* init libxsmm */ + LIBXSMM_INIT + + /* some check we can do before allocating the handle */ + if ( (rnncell_desc.datatype_in != rnncell_desc.datatype_out) || + ( (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) && (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) ) ) { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return NULL; + } + /* let's do some simple checks for BF16 as this limits the cell and architecture */ + if ( (rnncell_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (rnncell_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + if ( (LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid) || (rnncell_desc.C % 16 != 0) || (rnncell_desc.K % 16 != 0) ) { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return NULL; + } + } + /* we need at least one timestep */ + if (rnncell_desc.max_T < 1) { + *status = LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL; + return NULL; + } + + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_rnncell*)calloc(1, sizeof(libxsmm_dnn_rnncell)); + if (NULL != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* initialize known handle components */ + handle->desc = rnncell_desc; + /* set current seq length to max length */ + handle->T = rnncell_desc.max_T; + /* set blocking factors */ + handle->bk = (handle->desc.bk == 0) ? 64 : handle->desc.bk; + handle->bn = (handle->desc.bn == 0) ? 64 : handle->desc.bn; + handle->bc = (handle->desc.bc == 0) ? 64 : handle->desc.bc; + handle->use_fwd_fused_impl = handle->desc.use_fwd_fused_impl; + handle->fwd_block = handle->desc.fwd_block; + handle->bwdupd_block = handle->desc.bwdupd_block; + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + handle->lpb = 2; + } else { + handle->lpb = 1; + } + /* validate blocking factors */ + if ( handle->desc.N % handle->bn != 0 ) { + handle->bn = handle->desc.N; + *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING; + } + if ( handle->desc.C % handle->bc != 0 ) { + handle->bc = handle->desc.C; + *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING; + } + if ( handle->desc.K % handle->bk != 0 ) { + handle->bk = handle->desc.K; + *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING; + } + + /* If in SPR, generate tilerelease kernel */ + if ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) { + int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bk, handle->bk, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL); + } + + /* In case of BF16 for now hoist the BRGEMM and make them to use STRIDED variant by default */ + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + libxsmm_blasint BF, CB_BLOCKS, KB_BLOCKS; + const libxsmm_blasint K = handle->desc.K; + const libxsmm_blasint N = handle->desc.N; + const libxsmm_blasint C = handle->desc.C; + const libxsmm_blasint bk = handle->bk; + const libxsmm_blasint bn = handle->bn; + const libxsmm_blasint bc = handle->bc; + const libxsmm_blasint cBlocks = C/bc; + const libxsmm_blasint kBlocks = K/bk; + const libxsmm_blasint nBlocks = N/bn; + int tc_flags = 0; + int kernel_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + int stride_a, stride_b; + + if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) { + kernel_flags = ((handle->bk % 32 == 0) && (handle->bc % 32 == 0) && (handle->bn % 32 == 0)) ? LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG : 0; + kernel_flags = kernel_flags | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ); + } + + /* Blocking reduction domain if it is too large */ + BF = 1; + if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } + } + if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } + } + if (C == 2048 && K == 1024) { + BF = 2; + } + BF = handle->fwd_block; + + if (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) { + CB_BLOCKS = cBlocks/BF; + KB_BLOCKS = kBlocks/BF; + + /* define batch-reduce gemm kernels */ + stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bc * bn * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &bc, &bk, NULL, NULL, &kernel_flags, NULL ); + stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bk * bn * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &bk, &bk, NULL, NULL, &kernel_flags, NULL ); + if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) { + handle->fwd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL ); + } + + BF = handle->bwdupd_block; + KB_BLOCKS = kBlocks/BF; + + stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bk * bn * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &bk, &bc, NULL, NULL, &kernel_flags, NULL); + stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &bn, &bk, NULL, NULL, &kernel_flags, NULL); + stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bn * bc * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &bn, &bk, NULL, NULL, &kernel_flags, NULL); + stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &bk, &bk, NULL, NULL, &kernel_flags, NULL); + if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) { + handle->bwdupd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL); + } + } else { + CB_BLOCKS = cBlocks/BF; + KB_BLOCKS = kBlocks/BF; + + /* define batch-reduce gemm kernels */ + stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bc * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL ); + stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL ); + if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) { + handle->fwd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL ); + } + + BF = handle->bwdupd_block; + KB_BLOCKS = kBlocks/BF; + + stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL); + stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bn * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL); + stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bn * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL); + stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in); + handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL); + if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) { + handle->bwdupd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL); + } + } + } + + /* Need to allocate space for scratch libxsmm_dnn_tensor's, let's set all pointers to zero */ + handle->internal_z = 0; + handle->scratch_wT = 0; + handle->scratch_rT = 0; + handle->scratch_xT = 0; + handle->scratch_hT = 0; + handle->scratch_deltat = 0; + handle->scratch_di = 0; + handle->scratch_df = 0; + handle->scratch_do = 0; + handle->scratch_dci = 0; + handle->scratch_diB = 0; + handle->scratch_dfB = 0; + handle->scratch_dpB = 0; + handle->scratch_dciB = 0; + /* initialize a high-performant barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + if (NULL == handle->barrier) + { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + free(handle); + return NULL; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_rnncell*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) +{ + libxsmm_dnn_tensor_datalayout* layout; + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + if (layout != 0) { + if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) || + (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) || + (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) || + (type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) || + (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) || + (type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) || + (type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) || + (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) { + layout->format = handle->desc.buffer_format; + layout->tensor_type = LIBXSMM_DNN_ACTIVATION; + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T; + layout->dim_size[0] = (unsigned int)handle->bc; + layout->dim_size[1] = (unsigned int)handle->bn; + layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn); + layout->dim_size[4] = (unsigned int)handle->desc.max_T; + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) || + (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) || + (type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) || + (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) || + (type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) || + (type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) || + (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bn; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn); + layout->dim_size[4] = (unsigned int)handle->desc.max_T; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 3; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T; + layout->dim_size[0] = (unsigned int)handle->desc.C; + layout->dim_size[1] = (unsigned int)handle->desc.N; + layout->dim_size[2] = (unsigned int)handle->desc.max_T; + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) || + (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) || + (type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) || + (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) || + (type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) || + (type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) || + (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T; + layout->dim_size[0] = (unsigned int)handle->desc.K; + layout->dim_size[1] = (unsigned int)handle->desc.N; + layout->dim_size[2] = (unsigned int)handle->desc.max_T; + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) || + (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) { + layout->format = handle->desc.filter_format; + layout->tensor_type = LIBXSMM_DNN_FILTER; + if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) { + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = handle->desc.datatype_in; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bc; + layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[4] = 4; + } else { + layout->dim_size[4] = 3; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[4] = 4; + } else { + layout->dim_size[4] = 3; + } + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bc; + layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = handle->desc.datatype_in; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 6; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[5] = 4; + } else { + layout->dim_size[5] = 3; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[5] = 4; + } else { + layout->dim_size[5] = 3; + } + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[0] = (unsigned int)(handle->desc.K * 4); + layout->dim_size[1] = (unsigned int)handle->desc.C; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_size[0] = (unsigned int)(handle->desc.K * 3); + layout->dim_size[1] = (unsigned int)handle->desc.C; + } else { + layout->dim_size[0] = (unsigned int)handle->desc.K; + layout->dim_size[1] = (unsigned int)handle->desc.C; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[0] = (unsigned int)(handle->desc.K * 4); + layout->dim_size[1] = (unsigned int)handle->desc.K; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_size[0] = (unsigned int)(handle->desc.K * 3); + layout->dim_size[1] = (unsigned int)handle->desc.K; + } else { + layout->dim_size[0] = (unsigned int)handle->desc.K; + layout->dim_size[1] = (unsigned int)handle->desc.K; + } + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) || (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) { + layout->format = handle->desc.filter_format; + layout->tensor_type = LIBXSMM_DNN_FILTER; + if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) { + if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) { + layout->datatype = handle->desc.datatype_in; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->bc; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[4] = 4; + } else { + layout->dim_size[4] = 3; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[4] = 4; + } else { + layout->dim_size[4] = 3; + } + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 4; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = (unsigned int)handle->bc; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc); + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->bk; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) { + layout->datatype = handle->desc.datatype_in; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 6; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bc; + layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[5] = 4; + } else { + layout->dim_size[5] = 3; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[5] = 4; + } else { + layout->dim_size[5] = 3; + } + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 5; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bc; + layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc); + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_size[0] = (unsigned int)handle->lpb; + layout->dim_size[1] = (unsigned int)handle->bk; + layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb); + layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk); + layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk); + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 2; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[0] = (unsigned int)handle->desc.C; + layout->dim_size[1] = (unsigned int)(handle->desc.K * 4); + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_size[0] = (unsigned int)handle->desc.C; + layout->dim_size[1] = (unsigned int)(handle->desc.K * 3); + } else { + layout->dim_size[0] = (unsigned int)handle->desc.C; + layout->dim_size[1] = (unsigned int)handle->desc.K; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[0] = (unsigned int)handle->desc.K; + layout->dim_size[1] = (unsigned int)(handle->desc.K * 4); + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_size[0] = (unsigned int)handle->desc.K; + layout->dim_size[1] = (unsigned int)(handle->desc.K * 3); + } else { + layout->dim_size[0] = (unsigned int)handle->desc.K; + layout->dim_size[1] = (unsigned int)handle->desc.K; + } + } else { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) { + layout->format = handle->desc.buffer_format; + layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR; + + + if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) || ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) { + if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) { + layout->datatype = handle->desc.datatype_in; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */ + layout->num_dims = 1; + + if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) { + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + layout->dim_size[0] = (unsigned int)(handle->desc.K * 4); + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + layout->dim_size[0] = (unsigned int)(handle->desc.K * 3); + } else { + layout->dim_size[0] = (unsigned int)handle->desc.K; + } + } else { /* coverity[dead_error_begin] */ + free(layout->dim_type); + free(layout->dim_size); + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + return layout; +} + + +LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status) +{ + size_t size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in); + const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in; + + switch (handle->desc.cell_type) { + case LIBXSMM_DNN_RNNCELL_RNN_RELU: + case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID: + case LIBXSMM_DNN_RNNCELL_RNN_TANH: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + size += 0; + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in + 64; /* wT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in + 64; /* rT */ + size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; /* xT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) * (size_t)handle->desc.max_T + 64; /* deltat */ + + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_LSTM: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* w */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* r */ + /* The scratches below are needed only for BF16 code for the intermediate results */ + if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) { + size += (size_t)7 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64); /* intermediate scratches */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64; /* intermediate scratches */ + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + size += (size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize * 4 + 4 * 64; /* w */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize * 4 + 4 * 64; /* r */ + size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* wT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* rT */ + size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; /* xT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; /* deltat */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* di */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* df */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* do */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dci */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* diB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dfB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dpB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dciB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t1 */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t2 */ + /* The scratches below are needed only for BF16 code for the intermediate results */ + if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) { + size += (size_t)4 *((size_t)handle->desc.K * sizeof(float) + 64); /* intermediate db scratch */ + size += (size_t)handle->desc.C * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; /* intermediate dx scratches */ + size += (size_t)7 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64); /* intermediate scratches */ + size += (size_t)2 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64); /* intermediate scratches */ + } + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_GRU: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* w */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* r */ + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + size += (size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize * 3 + 3 * 64; /* w */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize * 3 + 3 * 64; /* r */ + size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* wT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* rT */ + size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; /* xT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; /* deltat */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* di */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dc */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* df */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* do */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* diB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dcB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dfB */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* oT */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t1 */ + size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t2 */ + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE; + } + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return size; +} + + +LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + return handle->scratch_base; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return NULL; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (NULL != handle) { + const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in); + const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + switch (handle->desc.cell_type) { + case LIBXSMM_DNN_RNNCELL_RNN_RELU: + case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID: + case LIBXSMM_DNN_RNNCELL_RNN_TANH: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + /* forward only has no scratch need */ + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + handle->scratch_base = (void*)address; + /* wT */ + if (address % 64 == 0) { + handle->scratch_wT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_wT = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) + 64; + /* rT */ + if (address % 64 == 0) { + handle->scratch_rT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_rT = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) + 64; + /* xT */ + if (address % 64 == 0) { + handle->scratch_xT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_xT = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in) + 64; + /* hT */ + if (address % 64 == 0) { + handle->scratch_hT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_hT = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out)) + 64; + /* deltat */ + if (address % 64 == 0) { + handle->scratch_deltat = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_deltat = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) * (size_t)handle->desc.max_T) + 64; + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_LSTM: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + handle->scratch_base = (void*)address; + /* w scratch */ + if (address % 64 == 0) { + handle->scratch_w = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_w = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 4 + 64; + /* r scratch */ + if (address % 64 == 0) { + handle->scratch_r = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_r = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 4 + 64; + /* The scratches below are needed only for BF16 code for the intermediate results */ + if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) { + /* cst scratch */ + if (address % 64 == 0) { + handle->cst_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->cst_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* ht scratch */ + if (address % 64 == 0) { + handle->ht_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->ht_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* it scratch */ + if (address % 64 == 0) { + handle->it_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->it_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* ft scratch */ + if (address % 64 == 0) { + handle->ft_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->ft_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* ot scratch */ + if (address % 64 == 0) { + handle->ot_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->ot_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* cit scratch */ + if (address % 64 == 0) { + handle->cit_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->cit_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* cot scratch */ + if (address % 64 == 0) { + handle->cot_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->cot_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* csp scratch */ + if (address % 64 == 0) { + handle->csp_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->csp_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64; + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + handle->scratch_base = (void*)address; + /* w scratch */ + if (address % 64 == 0) { + handle->scratch_w = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_w = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize) * 4 + 64; + /* r scratch */ + if (address % 64 == 0) { + handle->scratch_r = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_r = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize) * 4 + 64; + /* wT */ + if (address % 64 == 0) { + handle->scratch_wT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_wT = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 4 + 64; + /* rT */ + if (address % 64 == 0) { + handle->scratch_rT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_rT = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 4 + 64; + /* xT */ + if (address % 64 == 0) { + handle->scratch_xT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_xT = (void*)(address+offset); + } + address += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; + /* hT */ + if (address % 64 == 0) { + handle->scratch_hT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_hT = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* deltat */ + if (address % 64 == 0) { + handle->scratch_deltat = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_deltat = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; + /* di */ + if (address % 64 == 0) { + handle->scratch_di = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_di = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* df */ + if (address % 64 == 0) { + handle->scratch_df = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_df = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* do */ + if (address % 64 == 0) { + handle->scratch_do = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_do = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dci */ + if (address % 64 == 0) { + handle->scratch_dci = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dci = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* diB */ + if (address % 64 == 0) { + handle->scratch_diB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_diB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dfB */ + if (address % 64 == 0) { + handle->scratch_dfB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dfB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dpB */ + if (address % 64 == 0) { + handle->scratch_dpB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dpB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dciB */ + if (address % 64 == 0) { + handle->scratch_dciB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dciB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* t1 */ + if (address % 64 == 0) { + handle->scratch_t1 = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_t1 = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* t2 */ + if (address % 64 == 0) { + handle->scratch_t2 = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_t2 = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* The scratches below are needed only for BF16 code for the intermediate results */ + if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) { + /* dx scratch */ + if (address % 64 == 0) { + handle->scratch_dx = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dx = (void*)(address+offset); + } + address += (size_t)handle->desc.C * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* dhp scratch */ + if (address % 64 == 0) { + handle->scratch_dhp = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dhp = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64; + /* db scratch */ + if (address % 64 == 0) { + handle->scratch_db = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_db = (void*)(address+offset); + } + address += (size_t)handle->desc.K * 4 * sizeof(float) + 64; + /* cst scratch */ + if (address % 64 == 0) { + handle->cst_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->cst_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* ht scratch */ + if (address % 64 == 0) { + handle->ht_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->ht_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* it scratch */ + if (address % 64 == 0) { + handle->it_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->it_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* ft scratch */ + if (address % 64 == 0) { + handle->ft_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->ft_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* ot scratch */ + if (address % 64 == 0) { + handle->ot_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->ot_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* cit scratch */ + if (address % 64 == 0) { + handle->cit_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->cit_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* cot scratch */ + if (address % 64 == 0) { + handle->cot_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->cot_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; + /* csp scratch */ + if (address % 64 == 0) { + handle->csp_scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->csp_scratch = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_GRU: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + handle->scratch_base = (void*)address; + /* w scratch */ + if (address % 64 == 0) { + handle->scratch_w = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_w = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 3 + 64; + /* r scratch */ + if (address % 64 == 0) { + handle->scratch_r = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_r = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 3 + 64; + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + handle->scratch_base = (void*)address; + /* w scratch */ + if (address % 64 == 0) { + handle->scratch_w = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_w = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize) * 3 + 64; + /* r scratch */ + if (address % 64 == 0) { + handle->scratch_r = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_r = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize) * 3 + 64; + /* wT */ + if (address % 64 == 0) { + handle->scratch_wT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_wT = (void*)(address+offset); + } + address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 3 + 64; + /* rT */ + if (address % 64 == 0) { + handle->scratch_rT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_rT = (void*)(address+offset); + } + address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 3 + 64; + /* xT */ + if (address % 64 == 0) { + handle->scratch_xT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_xT = (void*)(address+offset); + } + address += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; + /* hT */ + if (address % 64 == 0) { + handle->scratch_hT = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_hT = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* deltat */ + if (address % 64 == 0) { + handle->scratch_deltat = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_deltat = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; + /* di */ + if (address % 64 == 0) { + handle->scratch_di = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_di = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dc */ + if (address % 64 == 0) { + handle->scratch_dci = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dci = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* df */ + if (address % 64 == 0) { + handle->scratch_df = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_df = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* do */ + if (address % 64 == 0) { + handle->scratch_do = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_do = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* diB */ + if (address % 64 == 0) { + handle->scratch_diB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_diB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dcB */ + if (address % 64 == 0) { + handle->scratch_dciB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dciB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* dfB */ + if (address % 64 == 0) { + handle->scratch_dfB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dfB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* doB (repurposed for oT) */ + if (address % 64 == 0) { + handle->scratch_dpB = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_dpB = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* t1 */ + if (address % 64 == 0) { + handle->scratch_t1 = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_t1 = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + /* t2 */ + if (address % 64 == 0) { + handle->scratch_t2 = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch_t2 = (void*)(address+offset); + } + address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (handle->desc.cell_type) { + case LIBXSMM_DNN_RNNCELL_RNN_RELU: + case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID: + case LIBXSMM_DNN_RNNCELL_RNN_TANH: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + /* forward only has no scratch need */ + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + handle->scratch_wT = 0; + handle->scratch_rT = 0; + handle->scratch_xT = 0; + handle->scratch_hT = 0; + handle->scratch_deltat = 0; + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_LSTM: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + handle->scratch_w = 0; + handle->scratch_r = 0; + handle->csp_scratch = 0; + handle->cst_scratch = 0; + handle->ht_scratch = 0; + handle->it_scratch = 0; + handle->ft_scratch = 0; + handle->ot_scratch = 0; + handle->cit_scratch = 0; + handle->cot_scratch = 0; + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + handle->scratch_w = 0; + handle->scratch_r = 0; + handle->scratch_wT = 0; + handle->scratch_rT = 0; + handle->scratch_xT = 0; + handle->scratch_hT = 0; + handle->scratch_deltat = 0; + handle->scratch_di = 0; + handle->scratch_df = 0; + handle->scratch_do = 0; + handle->scratch_dci = 0; + handle->scratch_diB = 0; + handle->scratch_dfB = 0; + handle->scratch_dpB = 0; + handle->scratch_dciB = 0; + handle->scratch_t1 = 0; + handle->scratch_t2 = 0; + handle->csp_scratch = 0; + handle->cst_scratch = 0; + handle->ht_scratch = 0; + handle->it_scratch = 0; + handle->ft_scratch = 0; + handle->ot_scratch = 0; + handle->cit_scratch = 0; + handle->cot_scratch = 0; + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_GRU: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + handle->scratch_w = 0; + handle->scratch_r = 0; + handle->ht_scratch = 0; + handle->it_scratch = 0; + handle->cit_scratch = 0; + handle->ft_scratch = 0; + handle->ot_scratch = 0; + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + handle->scratch_w = 0; + handle->scratch_r = 0; + handle->scratch_wT = 0; + handle->scratch_rT = 0; + handle->scratch_xT = 0; + handle->scratch_hT = 0; + handle->scratch_deltat = 0; + handle->scratch_di = 0; + handle->scratch_dci = 0; + handle->scratch_df = 0; + handle->scratch_do = 0; + handle->scratch_diB = 0; + handle->scratch_dciB = 0; + handle->scratch_dfB = 0; + handle->scratch_dpB = 0; + handle->scratch_t1 = 0; + handle->scratch_t2 = 0; + handle->ht_scratch = 0; + handle->it_scratch = 0; + handle->ft_scratch = 0; + handle->ot_scratch = 0; + handle->cit_scratch = 0; + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status) +{ + size_t size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + const size_t sizeof_datatype = sizeof(float); + + switch (handle->desc.cell_type) { + case LIBXSMM_DNN_RNNCELL_RNN_RELU: + case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID: + case LIBXSMM_DNN_RNNCELL_RNN_TANH: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */ + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */ + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_LSTM: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */ + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */ + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_GRU: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + /* with i, f, c, o exposed as i/o, there is currently no need for internal state */ + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + /* with i, f, c, o exposed as i/o, there is currently no need for internal state */ + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE; + } + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return size; +} + + +LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + return handle->internal_z; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return NULL; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)internalstate; + size_t offset = 0; + + if (0 != handle) { + switch (handle->desc.cell_type) { + case LIBXSMM_DNN_RNNCELL_RNN_RELU: + case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID: + case LIBXSMM_DNN_RNNCELL_RNN_TANH: { + if (internalstate == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + if (address % 64 == 0) { + handle->internal_z = (void*)address; + } else { + offset = (64 - address % 64); + handle->internal_z = (void*)(address+offset); + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + if (address % 64 == 0) { + handle->internal_z = (void*)address; + } else { + offset = (64 - address % 64); + handle->internal_z = (void*)(address+offset); + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_LSTM: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_GRU: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (handle->desc.cell_type) { + case LIBXSMM_DNN_RNNCELL_RNN_RELU: + case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID: + case LIBXSMM_DNN_RNNCELL_RNN_TANH: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + handle->internal_z = 0; + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + handle->internal_z = 0; + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_LSTM: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + case LIBXSMM_DNN_RNNCELL_GRU: { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: + case LIBXSMM_DNN_COMPUTE_KIND_ALL: { + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (handle != 0) { + handle->forget_bias = forget_bias; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) && + (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) && + (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) && + (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) && + (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) && + (type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) && + (type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) && + (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) && + (type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) && + (type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) && + (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_rnncell_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) { + handle->xt = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) { + handle->dxt = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) { + handle->csp = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) { + handle->dcsp = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) { + handle->hp = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) { + handle->dhp = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) { + handle->w = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) { + handle->wt = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) { + handle->dw = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) { + handle->r = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) { + handle->rt = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) { + handle->dr = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) { + handle->b = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) { + handle->db = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) { + handle->cst = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) { + handle->dcs = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) { + handle->ht = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) { + handle->dht = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) { + handle->it = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) { + handle->ft = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) { + handle->ot = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) { + handle->cit = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) { + handle->cot = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) +{ + libxsmm_dnn_tensor* tensor = 0; + LIBXSMM_UNUSED(status/*TODO*/); + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) && + (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) && + (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) && + (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) && + (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) && + (type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) && + (type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) && + (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) && + (type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) && + (type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) && + (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) { + return tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) { + tensor = handle->xt; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) { + tensor = handle->dxt; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) { + tensor = handle->csp; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) { + tensor = handle->dcsp; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) { + tensor = handle->hp; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) { + tensor = handle->dhp; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) { + tensor = handle->w; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) { + tensor = handle->wt; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) { + tensor = handle->dw; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) { + tensor = handle->r; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) { + tensor = handle->rt; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) { + tensor = handle->dr; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) { + tensor = handle->b; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) { + tensor = handle->db; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) { + tensor = handle->cst; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) { + tensor = handle->dcs; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) { + tensor = handle->ht; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) { + tensor = handle->dht; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) { + tensor = handle->it; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) { + tensor = handle->ft; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) { + tensor = handle->ot; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) { + tensor = handle->cit; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) { + tensor = handle->cot; + } else { + /* cannot happen */ + } + } + + return tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) && + (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) && + (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) && + (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) && + (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) && + (type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) && + (type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) && + (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) && + (type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) && + (type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) && + (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) { + handle->xt = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) { + handle->dxt = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) { + handle->csp = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) { + handle->dcsp = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) { + handle->hp = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) { + handle->dhp = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) { + handle->w = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) { + handle->wt = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) { + handle->dw = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) { + handle->r = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) { + handle->rt = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) { + handle->dr = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) { + handle->b = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) { + handle->db = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) { + handle->cst = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) { + handle->dcs = 0; + } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) { + handle->ht = 0; + } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) { + handle->dht = 0; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) { + handle->it = 0; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) { + handle->ft = 0; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) { + handle->ot = 0; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) { + handle->cit = 0; + } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) { + handle->cot = 0; + } else { + /* cannot happen */ + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T ) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + if ( handle->desc.max_T < T ) { + status = LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN; + } else { + handle->T = T; + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status ) { + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + return handle->T; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return 0; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK) ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck( handle, start_thread, tid ); + } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck( handle, start_thread, tid ); + } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck( handle, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: + case LIBXSMM_DNN_COMPUTE_KIND_UPD: + case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: { + if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK) ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck( handle, kind, start_thread, tid ); + } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck( handle, kind, start_thread, tid ); + } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) { + status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck( handle, kind, start_thread, tid ); + } else { + status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.c b/third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.c new file mode 100644 index 0000000000000000000000000000000000000000..54cef8b6028e9658263d08e089e74434cadc72d9 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.c @@ -0,0 +1,1016 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Kunal Banerjee, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_rnncell_backward_weight_update.h" +#include "libxsmm_dnn_elementwise.h" +#include "libxsmm_main.h" + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +void trans_act(short int *in, short int *out) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + __m512i v0, v1, v2, v3, v4, v5, v6, v7; + const __m512i idx_v = _mm512_set_epi64(13, 12, 7, 6, 9, 8, 3, 2); + const __mmask8 mask0 = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(204); + const __mmask8 mask1 = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(51); + const int in_width = 32, out_width = 32; + + r0 = _mm512_loadu_si512(in + 0*in_width); + r1 = _mm512_loadu_si512(in + 1*in_width); + t0 = _mm512_unpacklo_epi16(r0,r1); + t1 = _mm512_unpackhi_epi16(r0,r1); + r2 = _mm512_loadu_si512(in + 2*in_width); + r3 = _mm512_loadu_si512(in + 3*in_width); + t2 = _mm512_unpacklo_epi16(r2,r3); + t3 = _mm512_unpackhi_epi16(r2,r3); + r4 = _mm512_loadu_si512(in + 4*in_width); + r5 = _mm512_loadu_si512(in + 5*in_width); + t4 = _mm512_unpacklo_epi16(r4,r5); + t5 = _mm512_unpackhi_epi16(r4,r5); + r6 = _mm512_loadu_si512(in + 6*in_width); + r7 = _mm512_loadu_si512(in + 7*in_width); + t6 = _mm512_unpacklo_epi16(r6,r7); + t7 = _mm512_unpackhi_epi16(r6,r7); + r8 = _mm512_loadu_si512(in + 8*in_width); + r9 = _mm512_loadu_si512(in + 9*in_width); + t8 = _mm512_unpacklo_epi16(r8,r9); + t9 = _mm512_unpackhi_epi16(r8,r9); + ra = _mm512_loadu_si512(in + 10*in_width); + rb = _mm512_loadu_si512(in + 11*in_width); + ta = _mm512_unpacklo_epi16(ra,rb); + tb = _mm512_unpackhi_epi16(ra,rb); + rc = _mm512_loadu_si512(in + 12*in_width); + rd = _mm512_loadu_si512(in + 13*in_width); + tc = _mm512_unpacklo_epi16(rc,rd); + td = _mm512_unpackhi_epi16(rc,rd); + re = _mm512_loadu_si512(in + 14*in_width); + rf = _mm512_loadu_si512(in + 15*in_width); + te = _mm512_unpacklo_epi16(re,rf); + tf = _mm512_unpackhi_epi16(re,rf); + + r0 = _mm512_unpacklo_epi32(t0,t2); + r1 = _mm512_unpackhi_epi32(t0,t2); + r2 = _mm512_unpacklo_epi32(t1,t3); + r3 = _mm512_unpackhi_epi32(t1,t3); + r4 = _mm512_unpacklo_epi32(t4,t6); + r5 = _mm512_unpackhi_epi32(t4,t6); + r6 = _mm512_unpacklo_epi32(t5,t7); + r7 = _mm512_unpackhi_epi32(t5,t7); + r8 = _mm512_unpacklo_epi32(t8,ta); + r9 = _mm512_unpackhi_epi32(t8,ta); + ra = _mm512_unpacklo_epi32(t9,tb); + rb = _mm512_unpackhi_epi32(t9,tb); + rc = _mm512_unpacklo_epi32(tc,te); + rd = _mm512_unpackhi_epi32(tc,te); + re = _mm512_unpacklo_epi32(td,tf); + rf = _mm512_unpackhi_epi32(td,tf); + + t0 = _mm512_unpacklo_epi64(r0,r4); + t1 = _mm512_unpackhi_epi64(r0,r4); + t2 = _mm512_unpacklo_epi64(r1,r5); + t3 = _mm512_unpackhi_epi64(r1,r5); + t4 = _mm512_unpacklo_epi64(r2,r6); + t5 = _mm512_unpackhi_epi64(r2,r6); + t6 = _mm512_unpacklo_epi64(r3,r7); + t7 = _mm512_unpackhi_epi64(r3,r7); + t8 = _mm512_unpacklo_epi64(r8,rc); + t9 = _mm512_unpackhi_epi64(r8,rc); + ta = _mm512_unpacklo_epi64(r9,rd); + tb = _mm512_unpackhi_epi64(r9,rd); + tc = _mm512_unpacklo_epi64(ra,re); + td = _mm512_unpackhi_epi64(ra,re); + te = _mm512_unpacklo_epi64(rb,rf); + tf = _mm512_unpackhi_epi64(rb,rf); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r2 = _mm512_shuffle_i32x4(t4, t5, 0x88); + r3 = _mm512_shuffle_i32x4(t6, t7, 0x88); + r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd); + r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd); + r8 = _mm512_shuffle_i32x4(t8, t9, 0x88); + r9 = _mm512_shuffle_i32x4(ta, tb, 0x88); + ra = _mm512_shuffle_i32x4(tc, td, 0x88); + rb = _mm512_shuffle_i32x4(te, tf, 0x88); + rc = _mm512_shuffle_i32x4(t8, t9, 0xdd); + rd = _mm512_shuffle_i32x4(ta, tb, 0xdd); + re = _mm512_shuffle_i32x4(tc, td, 0xdd); + rf = _mm512_shuffle_i32x4(te, tf, 0xdd); + + v0 = _mm512_permutex2var_epi64(r0, idx_v, r8); + t0 = _mm512_mask_blend_epi64( mask0, r0, v0); + _mm256_storeu_si256((__m256i*)(out + 0*out_width), _mm512_extracti64x4_epi64(t0, 0)); + _mm256_storeu_si256((__m256i*)(out + 1*out_width), _mm512_extracti64x4_epi64(t0, 1)); + t8 = _mm512_mask_blend_epi64( mask1, r8, v0); + _mm256_storeu_si256((__m256i*)(out + 16*out_width), _mm512_extracti64x4_epi64(t8, 0)); + _mm256_storeu_si256((__m256i*)(out + 17*out_width), _mm512_extracti64x4_epi64(t8, 1)); + v1 = _mm512_permutex2var_epi64(r1, idx_v, r9); + t1 = _mm512_mask_blend_epi64( mask0, r1, v1); + _mm256_storeu_si256((__m256i*)(out + 2*out_width), _mm512_extracti64x4_epi64(t1, 0)); + _mm256_storeu_si256((__m256i*)(out + 3*out_width), _mm512_extracti64x4_epi64(t1, 1)); + t9 = _mm512_mask_blend_epi64( mask1, r9, v1); + _mm256_storeu_si256((__m256i*)(out + 18*out_width), _mm512_extracti64x4_epi64(t9, 0)); + _mm256_storeu_si256((__m256i*)(out + 19*out_width), _mm512_extracti64x4_epi64(t9, 1)); + v2 = _mm512_permutex2var_epi64(r2, idx_v, ra); + t2 = _mm512_mask_blend_epi64( mask0, r2, v2); + _mm256_storeu_si256((__m256i*)(out + 4*out_width), _mm512_extracti64x4_epi64(t2, 0)); + _mm256_storeu_si256((__m256i*)(out + 5*out_width), _mm512_extracti64x4_epi64(t2, 1)); + ta = _mm512_mask_blend_epi64( mask1, ra, v2); + _mm256_storeu_si256((__m256i*)(out + 20*out_width), _mm512_extracti64x4_epi64(ta, 0)); + _mm256_storeu_si256((__m256i*)(out + 21*out_width), _mm512_extracti64x4_epi64(ta, 1)); + v3 = _mm512_permutex2var_epi64(r3, idx_v, rb); + t3 = _mm512_mask_blend_epi64( mask0, r3, v3); + _mm256_storeu_si256((__m256i*)(out + 6*out_width), _mm512_extracti64x4_epi64(t3, 0)); + _mm256_storeu_si256((__m256i*)(out + 7*out_width), _mm512_extracti64x4_epi64(t3, 1)); + tb = _mm512_mask_blend_epi64( mask1, rb, v3); + _mm256_storeu_si256((__m256i*)(out + 22*out_width), _mm512_extracti64x4_epi64(tb, 0)); + _mm256_storeu_si256((__m256i*)(out + 23*out_width), _mm512_extracti64x4_epi64(tb, 1)); + v4 = _mm512_permutex2var_epi64(r4, idx_v, rc); + t4 = _mm512_mask_blend_epi64( mask0, r4, v4); + _mm256_storeu_si256((__m256i*)(out + 8*out_width), _mm512_extracti64x4_epi64(t4, 0)); + _mm256_storeu_si256((__m256i*)(out + 9*out_width), _mm512_extracti64x4_epi64(t4, 1)); + tc = _mm512_mask_blend_epi64( mask1, rc, v4); + _mm256_storeu_si256((__m256i*)(out + 24*out_width), _mm512_extracti64x4_epi64(tc, 0)); + _mm256_storeu_si256((__m256i*)(out + 25*out_width), _mm512_extracti64x4_epi64(tc, 1)); + v5 = _mm512_permutex2var_epi64(r5, idx_v, rd); + t5 = _mm512_mask_blend_epi64( mask0, r5, v5); + _mm256_storeu_si256((__m256i*)(out + 10*out_width), _mm512_extracti64x4_epi64(t5, 0)); + _mm256_storeu_si256((__m256i*)(out + 11*out_width), _mm512_extracti64x4_epi64(t5, 1)); + td = _mm512_mask_blend_epi64( mask1, rd, v5); + _mm256_storeu_si256((__m256i*)(out + 26*out_width), _mm512_extracti64x4_epi64(td, 0)); + _mm256_storeu_si256((__m256i*)(out + 27*out_width), _mm512_extracti64x4_epi64(td, 1)); + v6 = _mm512_permutex2var_epi64(r6, idx_v, re); + t6 = _mm512_mask_blend_epi64( mask0, r6, v6); + _mm256_storeu_si256((__m256i*)(out + 12*out_width), _mm512_extracti64x4_epi64(t6, 0)); + _mm256_storeu_si256((__m256i*)(out + 13*out_width), _mm512_extracti64x4_epi64(t6, 1)); + te = _mm512_mask_blend_epi64( mask1, re, v6); + _mm256_storeu_si256((__m256i*)(out + 28*out_width), _mm512_extracti64x4_epi64(te, 0)); + _mm256_storeu_si256((__m256i*)(out + 29*out_width), _mm512_extracti64x4_epi64(te, 1)); + v7 = _mm512_permutex2var_epi64(r7, idx_v, rf); + t7 = _mm512_mask_blend_epi64( mask0, r7, v7); + _mm256_storeu_si256((__m256i*)(out + 14*out_width), _mm512_extracti64x4_epi64(t7, 0)); + _mm256_storeu_si256((__m256i*)(out + 15*out_width), _mm512_extracti64x4_epi64(t7, 1)); + tf = _mm512_mask_blend_epi64( mask1, rf, v7); + _mm256_storeu_si256((__m256i*)(out + 30*out_width), _mm512_extracti64x4_epi64(tf, 0)); + _mm256_storeu_si256((__m256i*)(out + 31*out_width), _mm512_extracti64x4_epi64(tf, 1)); + + r0 = _mm512_loadu_si512(in + 16*32 + 0*in_width); + r1 = _mm512_loadu_si512(in + 16*32 + 1*in_width); + t0 = _mm512_unpacklo_epi16(r0,r1); + t1 = _mm512_unpackhi_epi16(r0,r1); + r2 = _mm512_loadu_si512(in + 16*32 + 2*in_width); + r3 = _mm512_loadu_si512(in + 16*32 + 3*in_width); + t2 = _mm512_unpacklo_epi16(r2,r3); + t3 = _mm512_unpackhi_epi16(r2,r3); + r4 = _mm512_loadu_si512(in + 16*32 + 4*in_width); + r5 = _mm512_loadu_si512(in + 16*32 + 5*in_width); + t4 = _mm512_unpacklo_epi16(r4,r5); + t5 = _mm512_unpackhi_epi16(r4,r5); + r6 = _mm512_loadu_si512(in + 16*32 + 6*in_width); + r7 = _mm512_loadu_si512(in + 16*32 + 7*in_width); + t6 = _mm512_unpacklo_epi16(r6,r7); + t7 = _mm512_unpackhi_epi16(r6,r7); + r8 = _mm512_loadu_si512(in + 16*32 + 8*in_width); + r9 = _mm512_loadu_si512(in + 16*32 + 9*in_width); + t8 = _mm512_unpacklo_epi16(r8,r9); + t9 = _mm512_unpackhi_epi16(r8,r9); + ra = _mm512_loadu_si512(in + 16*32 + 10*in_width); + rb = _mm512_loadu_si512(in + 16*32 + 11*in_width); + ta = _mm512_unpacklo_epi16(ra,rb); + tb = _mm512_unpackhi_epi16(ra,rb); + rc = _mm512_loadu_si512(in + 16*32 + 12*in_width); + rd = _mm512_loadu_si512(in + 16*32 + 13*in_width); + tc = _mm512_unpacklo_epi16(rc,rd); + td = _mm512_unpackhi_epi16(rc,rd); + re = _mm512_loadu_si512(in + 16*32 + 14*in_width); + rf = _mm512_loadu_si512(in + 16*32 + 15*in_width); + te = _mm512_unpacklo_epi16(re,rf); + tf = _mm512_unpackhi_epi16(re,rf); + + r0 = _mm512_unpacklo_epi32(t0,t2); + r1 = _mm512_unpackhi_epi32(t0,t2); + r2 = _mm512_unpacklo_epi32(t1,t3); + r3 = _mm512_unpackhi_epi32(t1,t3); + r4 = _mm512_unpacklo_epi32(t4,t6); + r5 = _mm512_unpackhi_epi32(t4,t6); + r6 = _mm512_unpacklo_epi32(t5,t7); + r7 = _mm512_unpackhi_epi32(t5,t7); + r8 = _mm512_unpacklo_epi32(t8,ta); + r9 = _mm512_unpackhi_epi32(t8,ta); + ra = _mm512_unpacklo_epi32(t9,tb); + rb = _mm512_unpackhi_epi32(t9,tb); + rc = _mm512_unpacklo_epi32(tc,te); + rd = _mm512_unpackhi_epi32(tc,te); + re = _mm512_unpacklo_epi32(td,tf); + rf = _mm512_unpackhi_epi32(td,tf); + + t0 = _mm512_unpacklo_epi64(r0,r4); + t1 = _mm512_unpackhi_epi64(r0,r4); + t2 = _mm512_unpacklo_epi64(r1,r5); + t3 = _mm512_unpackhi_epi64(r1,r5); + t4 = _mm512_unpacklo_epi64(r2,r6); + t5 = _mm512_unpackhi_epi64(r2,r6); + t6 = _mm512_unpacklo_epi64(r3,r7); + t7 = _mm512_unpackhi_epi64(r3,r7); + t8 = _mm512_unpacklo_epi64(r8,rc); + t9 = _mm512_unpackhi_epi64(r8,rc); + ta = _mm512_unpacklo_epi64(r9,rd); + tb = _mm512_unpackhi_epi64(r9,rd); + tc = _mm512_unpacklo_epi64(ra,re); + td = _mm512_unpackhi_epi64(ra,re); + te = _mm512_unpacklo_epi64(rb,rf); + tf = _mm512_unpackhi_epi64(rb,rf); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r2 = _mm512_shuffle_i32x4(t4, t5, 0x88); + r3 = _mm512_shuffle_i32x4(t6, t7, 0x88); + r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd); + r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd); + r8 = _mm512_shuffle_i32x4(t8, t9, 0x88); + r9 = _mm512_shuffle_i32x4(ta, tb, 0x88); + ra = _mm512_shuffle_i32x4(tc, td, 0x88); + rb = _mm512_shuffle_i32x4(te, tf, 0x88); + rc = _mm512_shuffle_i32x4(t8, t9, 0xdd); + rd = _mm512_shuffle_i32x4(ta, tb, 0xdd); + re = _mm512_shuffle_i32x4(tc, td, 0xdd); + rf = _mm512_shuffle_i32x4(te, tf, 0xdd); + + v0 = _mm512_permutex2var_epi64(r0, idx_v, r8); + t0 = _mm512_mask_blend_epi64( mask0, r0, v0); + _mm256_storeu_si256((__m256i*)(out + 16 + 0*out_width), _mm512_extracti64x4_epi64(t0, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 1*out_width), _mm512_extracti64x4_epi64(t0, 1)); + t8 = _mm512_mask_blend_epi64( mask1, r8, v0); + _mm256_storeu_si256((__m256i*)(out + 16 + 16*out_width), _mm512_extracti64x4_epi64(t8, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 17*out_width), _mm512_extracti64x4_epi64(t8, 1)); + v1 = _mm512_permutex2var_epi64(r1, idx_v, r9); + t1 = _mm512_mask_blend_epi64( mask0, r1, v1); + _mm256_storeu_si256((__m256i*)(out + 16 + 2*out_width), _mm512_extracti64x4_epi64(t1, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 3*out_width), _mm512_extracti64x4_epi64(t1, 1)); + t9 = _mm512_mask_blend_epi64( mask1, r9, v1); + _mm256_storeu_si256((__m256i*)(out + 16 + 18*out_width), _mm512_extracti64x4_epi64(t9, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 19*out_width), _mm512_extracti64x4_epi64(t9, 1)); + v2 = _mm512_permutex2var_epi64(r2, idx_v, ra); + t2 = _mm512_mask_blend_epi64( mask0, r2, v2); + _mm256_storeu_si256((__m256i*)(out + 16 + 4*out_width), _mm512_extracti64x4_epi64(t2, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 5*out_width), _mm512_extracti64x4_epi64(t2, 1)); + ta = _mm512_mask_blend_epi64( mask1, ra, v2); + _mm256_storeu_si256((__m256i*)(out + 16 + 20*out_width), _mm512_extracti64x4_epi64(ta, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 21*out_width), _mm512_extracti64x4_epi64(ta, 1)); + v3 = _mm512_permutex2var_epi64(r3, idx_v, rb); + t3 = _mm512_mask_blend_epi64( mask0, r3, v3); + _mm256_storeu_si256((__m256i*)(out + 16 + 6*out_width), _mm512_extracti64x4_epi64(t3, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 7*out_width), _mm512_extracti64x4_epi64(t3, 1)); + tb = _mm512_mask_blend_epi64( mask1, rb, v3); + _mm256_storeu_si256((__m256i*)(out + 16 + 22*out_width), _mm512_extracti64x4_epi64(tb, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 23*out_width), _mm512_extracti64x4_epi64(tb, 1)); + v4 = _mm512_permutex2var_epi64(r4, idx_v, rc); + t4 = _mm512_mask_blend_epi64( mask0, r4, v4); + _mm256_storeu_si256((__m256i*)(out + 16 + 8*out_width), _mm512_extracti64x4_epi64(t4, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 9*out_width), _mm512_extracti64x4_epi64(t4, 1)); + tc = _mm512_mask_blend_epi64( mask1, rc, v4); + _mm256_storeu_si256((__m256i*)(out + 16 + 24*out_width), _mm512_extracti64x4_epi64(tc, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 25*out_width), _mm512_extracti64x4_epi64(tc, 1)); + v5 = _mm512_permutex2var_epi64(r5, idx_v, rd); + t5 = _mm512_mask_blend_epi64( mask0, r5, v5); + _mm256_storeu_si256((__m256i*)(out + 16 + 10*out_width), _mm512_extracti64x4_epi64(t5, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 11*out_width), _mm512_extracti64x4_epi64(t5, 1)); + td = _mm512_mask_blend_epi64( mask1, rd, v5); + _mm256_storeu_si256((__m256i*)(out + 16 + 26*out_width), _mm512_extracti64x4_epi64(td, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 27*out_width), _mm512_extracti64x4_epi64(td, 1)); + v6 = _mm512_permutex2var_epi64(r6, idx_v, re); + t6 = _mm512_mask_blend_epi64( mask0, r6, v6); + _mm256_storeu_si256((__m256i*)(out + 16 + 12*out_width), _mm512_extracti64x4_epi64(t6, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 13*out_width), _mm512_extracti64x4_epi64(t6, 1)); + te = _mm512_mask_blend_epi64( mask1, re, v6); + _mm256_storeu_si256((__m256i*)(out + 16 + 28*out_width), _mm512_extracti64x4_epi64(te, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 29*out_width), _mm512_extracti64x4_epi64(te, 1)); + v7 = _mm512_permutex2var_epi64(r7, idx_v, rf); + t7 = _mm512_mask_blend_epi64( mask0, r7, v7); + _mm256_storeu_si256((__m256i*)(out + 16 + 14*out_width), _mm512_extracti64x4_epi64(t7, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 15*out_width), _mm512_extracti64x4_epi64(t7, 1)); + tf = _mm512_mask_blend_epi64( mask1, rf, v7); + _mm256_storeu_si256((__m256i*)(out + 16 + 30*out_width), _mm512_extracti64x4_epi64(tf, 0)); + _mm256_storeu_si256((__m256i*)(out + 16 + 31*out_width), _mm512_extracti64x4_epi64(tf, 1)); +#else + LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); +#endif +} + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +# define LIBXSMM_DNN_RNN_RELU_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c" +# undef LIBXSMM_DNN_RNN_RELU_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +# define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c" +# undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +# define LIBXSMM_DNN_RNN_TANH_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c" +# undef LIBXSMM_DNN_RNN_TANH_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c" + } else { + /* should not happen */ + } +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + return libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(handle, kind, start_thread, tid); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ +#define LIBXSMM_RNN_CELL_AVX512 +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ +#define LIBXSMM_RNN_CELL_AVX512 + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" + +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__ */ +#define LIBXSMM_RNN_CELL_AVX512 + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + return libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(handle, kind, start_thread, tid); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ +#define LIBXSMM_RNN_CELL_AVX512 +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ +#define LIBXSMM_RNN_CELL_AVX512 + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +# define LIBXSMM_DNN_RNN_RELU_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_RELU_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +# define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +# define LIBXSMM_DNN_RNN_TANH_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_TANH_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c" + } else { + /* should not happen */ + } +#undef LIBXSMM_RNN_CELL_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; +#if 0 + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_ncnc_kcck_generic.tpl.c" +#endif + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ +#if 0 + if (handle->? == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } +#endif + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32( handle, kind, start_thread, tid ); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) { + if ( handle->desc.N % 2 != 0 ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid ); + } else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16( handle, kind, start_thread, tid ); + } else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx( handle, kind, start_thread, tid ); + } +#else + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid ); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +#define LIBXSMM_DNN_RNN_RELU_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c" +#undef LIBXSMM_DNN_RNN_RELU_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +#define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c" +#undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +#define LIBXSMM_DNN_RNN_TANH_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c" +#undef LIBXSMM_DNN_RNN_TANH_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c" + } else { + /* should not happen */ + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ +#if 0 + if (handle->? == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } +#endif + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32( handle, kind, start_thread, tid ); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) { + if ( handle->desc.N % 2 != 0 ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid ); + } else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16( handle, kind, start_thread, tid ); + } else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid ); + } +#else + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid ); + } else if (libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid ); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +#define LIBXSMM_DNN_RNN_RELU_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_RELU_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +#define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +#define LIBXSMM_DNN_RNN_TANH_BWDUPD +# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_TANH_BWDUPD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c" + } else { + /* should not happen */ + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ +#if 0 + if (handle->? == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } +#endif + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid ); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#elif defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid ); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.h b/third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.h new file mode 100644 index 0000000000000000000000000000000000000000..47d5398845208f8ce332b8aadac1af58bf674492 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.h @@ -0,0 +1,21 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H +#define LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H + +#include +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.c b/third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..c61e41dff38cb1bdb32beca0a3e84c369ecd95d7 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.c @@ -0,0 +1,740 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_rnncell_forward.h" +#include "libxsmm_dnn_elementwise.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +# define LIBXSMM_DNN_RNN_RELU_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c" +# undef LIBXSMM_DNN_RNN_RELU_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +# define LIBXSMM_DNN_RNN_SIGMOID_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c" +# undef LIBXSMM_DNN_RNN_SIGMOID_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +# define LIBXSMM_DNN_RNN_TANH_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c" +# undef LIBXSMM_DNN_RNN_TANH_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c" + } else { + /* should not happen */ + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + return libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(handle, start_thread, tid); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__ */ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +# define LIBXSMM_DNN_RNN_RELU_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_RELU_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +# define LIBXSMM_DNN_RNN_SIGMOID_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_SIGMOID_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +# define LIBXSMM_DNN_RNN_TANH_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_TANH_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +# define LIBXSMM_DNN_RNN_RELU_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_RELU_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +# define LIBXSMM_DNN_RNN_SIGMOID_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_SIGMOID_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +# define LIBXSMM_DNN_RNN_TANH_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c" +# undef LIBXSMM_DNN_RNN_TANH_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c" + } else { + /* should not happen */ + } +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + return libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(handle, start_thread, tid); +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + +#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#else +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__ */ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef libxsmm_bfloat16 element_filter_type; + + /* some portable macrros fof BF16 <-> FP32 */ +# include "template/libxsmm_dnn_bf16_macros_define.tpl.c" + + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +#define LIBXSMM_RNN_CELL_AVX512 +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c" +#undef LIBXSMM_RNN_CELL_AVX512 + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + +# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} +#endif + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ +#if 0 + if (handle->? == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } +#endif + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx( handle, start_thread, tid); + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx( handle, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +#define LIBXSMM_DNN_RNN_RELU_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c" +#undef LIBXSMM_DNN_RNN_RELU_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +#define LIBXSMM_DNN_RNN_SIGMOID_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c" +#undef LIBXSMM_DNN_RNN_SIGMOID_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +#define LIBXSMM_DNN_RNN_TANH_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c" +#undef LIBXSMM_DNN_RNN_TANH_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c" + } else { + /* should not happen */ + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ +#if 0 + if (handle->? == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } +#endif + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#elif defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +#define LIBXSMM_DNN_RNN_RELU_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_RELU_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +#define LIBXSMM_DNN_RNN_SIGMOID_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_SIGMOID_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +#define LIBXSMM_DNN_RNN_TANH_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_TANH_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { + status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; + } else { + /* should not happen */ + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and filter */ +#if 0 + if (handle->? == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } +#endif + + /* check if we are on AVX512 */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) { + if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32( handle, start_thread, tid); + } +#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx( handle, start_thread, tid); + } +#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/ + else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid); + } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) { + status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx( handle, start_thread, tid); + } +#endif + else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef float element_filter_type; + if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) { +#define LIBXSMM_DNN_RNN_RELU_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_RELU_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) { +#define LIBXSMM_DNN_RNN_SIGMOID_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_SIGMOID_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) { +#define LIBXSMM_DNN_RNN_TANH_FWD +# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c" +#undef LIBXSMM_DNN_RNN_TANH_FWD + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) { +# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c" + } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) { +# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c" + } else { + /* should not happen */ + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} diff --git a/third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.h b/third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..7cb2efecba3d2dbbce399a7e5cff7fcc2df1ed47 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.h @@ -0,0 +1,21 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_RNNCELL_FORWARD_H +#define LIBXSMM_DNN_RNNCELL_FORWARD_H + +#include +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_RNNCELL_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_softmaxloss.c b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss.c new file mode 100644 index 0000000000000000000000000000000000000000..806f09fbef4d19029f34b77bd351152f4e2f2760 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss.c @@ -0,0 +1,382 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_softmaxloss_backward.h" +#include "libxsmm_dnn_softmaxloss_forward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API libxsmm_dnn_softmaxloss* libxsmm_dnn_create_softmaxloss(libxsmm_dnn_softmaxloss_desc softmaxloss_desc, libxsmm_dnn_err_t* status) { + libxsmm_dnn_softmaxloss* handle = 0; + int lpb; + + /* init libxsmm */ + LIBXSMM_INIT + + if ( (softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_F32) || (softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) ) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + handle = (libxsmm_dnn_softmaxloss*)calloc(1, sizeof(libxsmm_dnn_softmaxloss)); + + if (0 != handle) { + *status = LIBXSMM_DNN_SUCCESS; + /* let's make the description persistent */ + handle->desc = softmaxloss_desc; + + /* cnn */ + if ( (handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + int bk; + /* we need to compute the memory layout given the */ + *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C, + &(handle->bc), &bk, &lpb, + handle->desc.datatype, handle->desc.datatype ); + /* compute the outer blocks */ + handle->Bc = handle->desc.C / handle->bc; + handle->bn = 1; + handle->Bn = handle->desc.N; + } else if ( (handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0 ) { + handle->bc = handle->desc.bc; + handle->bn = handle->desc.bn; + handle->Bc = handle->desc.C / handle->bc; + handle->Bn = handle->desc.N / handle->bn; + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + free( handle ); + handle = 0; + return handle; + } + /* create barrier */ + handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1); + /* calculate scratch size for local softmaxloss copies of one feature map block per thread */ + if ( softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + handle->scratch_size = (sizeof(float)*handle->desc.C*handle->desc.N*2); + } else { + handle->scratch_size = 1; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_HANDLE; + } + } else { + *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + + return handle; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_softmaxloss(const libxsmm_dnn_softmaxloss* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + /* Deallocate barrier */ + if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_softmaxloss*)handle); + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_softmaxloss_create_tensor_datalayout(const libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* layout; + + *status = LIBXSMM_DNN_SUCCESS; + layout = 0; + + if (handle != 0) { + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + + if (layout != 0) { + layout->format = handle->desc.buffer_format; + + if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) || + (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) { + if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) { + layout->datatype = handle->desc.datatype; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 3; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->bc; + layout->dim_size[1] = handle->Bc; + layout->dim_size[2] = handle->desc.N; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) { + layout->datatype = handle->desc.datatype; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 4; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C; + layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->bc; + layout->dim_size[1] = handle->bn; + layout->dim_size[2] = handle->Bc; + layout->dim_size[3] = handle->Bn; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL; + } + } else if ( type == LIBXSMM_DNN_LABEL ) { + layout->datatype = LIBXSMM_DNN_DATATYPE_I32; + layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype)); + layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int)); + + if (0 != layout->dim_type && 0 != layout->dim_size) { + layout->num_dims = 1; + layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_N; + layout->dim_size[0] = handle->desc.N; + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS; + } + } else { + free(layout); + layout = 0; /* make sure a NULL is returned */ + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return layout; +} + + +LIBXSMM_API size_t libxsmm_dnn_softmaxloss_get_scratch_size(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status) { + size_t l_scratch_size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */ + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_scratch_size; +} + + +LIBXSMM_API void* libxsmm_dnn_softmaxloss_get_scratch_ptr(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + return handle->scratch; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return 0; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_scratch(libxsmm_dnn_softmaxloss* handle, const void* scratch) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + uintptr_t address = (uintptr_t)scratch; + size_t offset = 0; + + if (scratch == 0) { + status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED; + return status; + } + + if (0 != handle) { + /* align the internal scratch buffer if needed */ + if (address % 64 == 0) { + handle->scratch = (void*)address; + } else { + offset = (64 - address % 64); + handle->scratch = (void*)(address+offset); + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_scratch(libxsmm_dnn_softmaxloss* handle) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + handle->scratch = 0; + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_LABEL) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0 && tensor != 0) { + libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_softmaxloss_create_tensor_datalayout(handle, type, &status); + + if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = (libxsmm_dnn_tensor*)tensor; + } else if ( type == LIBXSMM_DNN_LABEL ) { + handle->label = (libxsmm_dnn_tensor*)tensor; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR; + } + + libxsmm_dnn_destroy_tensor_datalayout( handle_layout ); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_softmaxloss_get_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor* return_tensor = 0; + + *status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_LABEL) ) { + *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return return_tensor; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + return_tensor = handle->reg_input; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + return_tensor = handle->grad_input; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + return_tensor = handle->reg_output; + } else if ( type == LIBXSMM_DNN_LABEL ) { + return_tensor = handle->label; + } else { + /* cannot happen */ + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return return_tensor; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check for tensor type */ + if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) && + (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_LABEL) ) { + status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE; + return status; + } + + if (handle != 0) { + if ( type == LIBXSMM_DNN_REGULAR_INPUT ) { + handle->reg_input = 0; + } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) { + handle->grad_input = 0; + } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) { + handle->reg_output = 0; + } else if ( type == LIBXSMM_DNN_LABEL ) { + handle->label = 0; + } else { + /* cannot happen */ + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_execute_st(libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_compute_kind kind, + /*unsigned*/int start_thread, /*unsigned*/int tid) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + switch (kind) { + case LIBXSMM_DNN_COMPUTE_KIND_FWD: { + status = libxsmm_dnn_softmaxloss_st_fwd_ncnc( handle, start_thread, tid ); + } break; + case LIBXSMM_DNN_COMPUTE_KIND_BWD: { + status = libxsmm_dnn_softmaxloss_st_bwd_ncnc( handle, start_thread, tid ); + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_KIND; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return status; +} + +LIBXSMM_API float libxsmm_dnn_softmaxloss_get_loss(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status) { + float l_loss = 0.0f; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != handle) { + l_loss = handle->loss; + } else { + *status = LIBXSMM_DNN_ERR_INVALID_HANDLE; + } + + return l_loss; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.c b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.c new file mode 100644 index 0000000000000000000000000000000000000000..b9dd837cebbcce86090c6e7767bc59856f8acd72 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.c @@ -0,0 +1,103 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_softmaxloss_backward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef int element_label_type; + +# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef int element_label_type; + +# define LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512 +# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c" +# undef LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and mask */ + if ( handle->grad_input == 0 || handle->reg_output == 0 || handle->label == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32( handle, start_thread, tid); + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef int element_label_type; + +# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c" + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef int element_label_type; + +# define LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16 +# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c" +# undef LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.h b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..6fbe1b917a0d6d29fafc657828645ae1683ec7c4 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.h @@ -0,0 +1,18 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H +#define LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.c b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.c new file mode 100644 index 0000000000000000000000000000000000000000..ee351b2ab07e9505dd7679a15ee26222bccd9aa6 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.c @@ -0,0 +1,103 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_dnn_softmaxloss_forward.h" +#include "libxsmm_main.h" + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid); +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid); + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef float element_input_type; + typedef float element_output_type; + typedef int element_label_type; + +# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c" +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef int element_label_type; + +# define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512 +# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c" +# undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512 +#else /* should not happen */ + LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); + status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH; +#endif + return status; +} + + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* check if we have input, output and mask */ + if ( handle->reg_input == 0 || handle->reg_output == 0 || handle->label == 0 ) { + status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND; + return status; + } + + /* check if we are on an AVX512 platform */ +#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/ + if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32( handle, start_thread, tid); + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16( handle, start_thread, tid); + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } else +#endif + { + if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) { + typedef float element_input_type; + typedef float element_output_type; + typedef int element_label_type; + +# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c" + } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) { + typedef libxsmm_bfloat16 element_input_type; + typedef libxsmm_bfloat16 element_output_type; + typedef int element_label_type; + +# define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16 +# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c" +# undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16 + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + return status; + } + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.h b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..e40464b849a80b445e1e361ddad793f2c66efb06 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.h @@ -0,0 +1,18 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H +#define LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H + +#include + +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid); + +#endif /* LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H */ diff --git a/third_party/libxsmm/src/libxsmm_dnn_tensor.c b/third_party/libxsmm/src/libxsmm_dnn_tensor.c new file mode 100644 index 0000000000000000000000000000000000000000..e95010973930eebffc117dba2a5c9a4d1fa8c622 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_dnn_tensor.c @@ -0,0 +1,642 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" +#include "libxsmm_dnn_tensor.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(_OPENMP) +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status) +{ + return libxsmm_dnn_link_qtensor(layout, data, 0, status); +} + + +LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char scf, libxsmm_dnn_err_t* status) +{ + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + libxsmm_dnn_tensor* tensor = (libxsmm_dnn_tensor*)calloc(1, sizeof(libxsmm_dnn_tensor)); + *status = LIBXSMM_DNN_SUCCESS; + + if (layout != 0 && tensor != 0 && data != 0) { + tensor->layout = libxsmm_dnn_duplicate_tensor_datalayout(layout, status); + tensor->data = (void*)data; + tensor->scf = scf; + /* when layout copy failed, free layout */ + if (*status != LIBXSMM_DNN_SUCCESS) { + libxsmm_dnn_destroy_tensor_datalayout(tensor->layout); + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_TENSOR; + } + + if (*status != LIBXSMM_DNN_SUCCESS) { + free((libxsmm_dnn_tensor*)tensor); + tensor = 0; + } + + return tensor; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* dst_layout; + + *status = LIBXSMM_DNN_SUCCESS; + dst_layout = 0; + + if (layout != 0 && layout->num_dims != 0) { + unsigned int dim = 0; + + /* zero entire content; not only safer but also sets data and code pointers to NULL */ + dst_layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout)); + if (0 != dst_layout) { + dst_layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(layout->num_dims * sizeof(libxsmm_dnn_tensor_dimtype)); + dst_layout->dim_size = (unsigned int*)malloc(layout->num_dims * sizeof(unsigned int)); + dst_layout->num_dims = layout->num_dims; + dst_layout->format = layout->format; + dst_layout->datatype = layout->datatype; + dst_layout->tensor_type = layout->tensor_type; + if (0 != dst_layout->dim_type && 0 != dst_layout->dim_size) { + for (dim = 0; dim < layout->num_dims; ++dim) { + dst_layout->dim_type[dim] = layout->dim_type[dim]; + dst_layout->dim_size[dim] = layout->dim_size[dim]; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } else { + *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT; + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + } + + return dst_layout; +} + + +LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status) { + unsigned int result = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (layout_a != 0 && layout_b != 0) { + unsigned int dim = 0; + + if (layout_a->num_dims != layout_b->num_dims) { result = 1; } + if (layout_a->format != layout_b->format) { result = 1; } + if (layout_a->datatype != layout_b->datatype) { result = 1; } + + if (result == 0) { + for ( dim = 0; dim < layout_a->num_dims; ++dim ) { + if ( layout_a->dim_type[dim] != layout_b->dim_type[dim] ) { result = 1; } + if ( layout_a->dim_size[dim] != layout_b->dim_size[dim] ) { result = 1; } + } + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + result = 100; + } + + return result; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != layout) { + free(layout->dim_type); + free(layout->dim_size); + free(layout); + } + else { + status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + } + + return status; +} + + +LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) { + unsigned int size = 0; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != layout) { + unsigned int dim = 0; + size = (unsigned int)libxsmm_dnn_typesize(layout->datatype); + for (dim = 0; dim < layout->num_dims; ++dim) { + size *= layout->dim_size[dim]; + } + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + } + + return size; +} + + +LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) { + unsigned int elements = 1; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != layout) { + unsigned int dim = 0; + for ( dim = 0; dim < layout->num_dims; ++dim ) { + elements *= layout->dim_size[dim]; + } + } else { + *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + elements = 0; + } + + return elements; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data) { + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if ((0 != tensor) && (0 != data)) { + if (0 != tensor->layout) { + if (0 < tensor->layout->num_dims) { + tensor->data = (void*)data; + } else { + status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + } + } else { + status = LIBXSMM_DNN_ERR_INVALID_LAYOUT; + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return status; +} + + +LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != tensor) { + return tensor->data; + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return 0; +} + + +LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status) { + libxsmm_dnn_tensor_datalayout* dst_layout = NULL; + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != tensor) { + dst_layout = libxsmm_dnn_duplicate_tensor_datalayout( tensor->layout, status ); + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return dst_layout; +} + + +LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status) +{ + *status = LIBXSMM_DNN_SUCCESS; + + if (0 != tensor) { + return tensor->scf; + } + else { + *status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return 0; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != tensor) { + tensor->scf = scf; + } + else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != tensor) { /* it is not an error attempting to destroy a NULL-handle */ + /* free layout information stored in tensor */ + if (0 != tensor->layout) { + libxsmm_dnn_destroy_tensor_datalayout( (libxsmm_dnn_tensor_datalayout*)tensor->layout ); + } + /* deallocate handle structure */ + free(/*remove constness*/(libxsmm_dnn_tensor*)tensor); + } +#if 0 /* releasing a NULL-buffer should be not an error (similar to freeing a NULL pointer) */ + else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } +#endif + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* @TODO check for valid combination */ + + if (0 != tensor) { + switch (tensor->layout->tensor_type) { + case LIBXSMM_DNN_REGULAR_INPUT: + case LIBXSMM_DNN_GRADIENT_INPUT: + case LIBXSMM_DNN_REGULAR_OUTPUT: + case LIBXSMM_DNN_GRADIENT_OUTPUT: + case LIBXSMM_DNN_INPUT: + case LIBXSMM_DNN_OUTPUT: + case LIBXSMM_DNN_ACTIVATION: { + switch (in_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: { + if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + typedef float element_type; +#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_BF16: { + typedef libxsmm_bfloat16 element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I32: { + typedef int element_type; +#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + typedef short element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + typedef unsigned char element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT; + } + } + } break; + case LIBXSMM_DNN_REGULAR_FILTER: + case LIBXSMM_DNN_GRADIENT_FILTER: + case LIBXSMM_DNN_FILTER: { + switch (in_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: { + if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + typedef float element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_BF16: { + typedef libxsmm_bfloat16 element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + typedef short element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + typedef char element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c" + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT; + } + } + } break; + case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS: + case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: + case LIBXSMM_DNN_CHANNEL_BIAS: + case LIBXSMM_DNN_REGULAR_CHANNEL_BETA: + case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA: + case LIBXSMM_DNN_CHANNEL_BETA: + case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA: + case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA: + case LIBXSMM_DNN_CHANNEL_GAMMA: + case LIBXSMM_DNN_CHANNEL_EXPECTVAL: + case LIBXSMM_DNN_CHANNEL_RCPSTDDEV: + case LIBXSMM_DNN_CHANNEL_VARIANCE: + case LIBXSMM_DNN_CHANNEL_SCALAR: { + switch (in_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: { + if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + typedef float element_type; +#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_BF16: { + typedef libxsmm_bfloat16 element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + typedef short element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + typedef char element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + if (0 != tensor) { + const size_t size = libxsmm_dnn_get_tensor_elements( tensor->layout, &status ); + size_t i; + /* use for-loops to potentially leverage NUMA in the future */ + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + float* fp32_data = (float*)tensor->data; + for (i = 0; i < size; ++i) fp32_data[i] = 0.0f; + } break; + case LIBXSMM_DNN_DATATYPE_BF16: { + libxsmm_bfloat16* bfp16_data = (libxsmm_bfloat16*)tensor->data; + for (i = 0; i < size; ++i) bfp16_data[i] = 0; + } break; + case LIBXSMM_DNN_DATATYPE_I32: { + int* int32_data = (int*)tensor->data; + for (i = 0; i < size; ++i) int32_data[i] = 0; + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + short* int16_data = (short*)tensor->data; + for (i = 0; i < size; ++i) int16_data[i] = 0; + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + char* int8_data = (char*)tensor->data; + for (i = 0; i < size; ++i) int8_data[i] = 0; + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return status; +} + + +LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format) +{ + libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS; + + /* @TODO check for valid combination */ + + if (0 != tensor) { + switch (tensor->layout->tensor_type) { + case LIBXSMM_DNN_REGULAR_INPUT: + case LIBXSMM_DNN_GRADIENT_INPUT: + case LIBXSMM_DNN_REGULAR_OUTPUT: + case LIBXSMM_DNN_GRADIENT_OUTPUT: + case LIBXSMM_DNN_INPUT: + case LIBXSMM_DNN_OUTPUT: + case LIBXSMM_DNN_ACTIVATION: { + switch (out_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: { + if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + typedef float element_type; +#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_BF16: { + typedef libxsmm_bfloat16 element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I32: { + typedef int element_type; +#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + typedef short element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + typedef unsigned char element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT; + } + } + } break; + case LIBXSMM_DNN_REGULAR_FILTER: + case LIBXSMM_DNN_GRADIENT_FILTER: + case LIBXSMM_DNN_FILTER: { + switch (out_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: { + if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + typedef float element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c" + } break; + + case LIBXSMM_DNN_DATATYPE_BF16: { + typedef libxsmm_bfloat16 element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I32: { + typedef int element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + typedef short element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + typedef char element_type; +#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c" + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT; + } + } + } break; + case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS: + case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: + case LIBXSMM_DNN_CHANNEL_BIAS: + case LIBXSMM_DNN_REGULAR_CHANNEL_BETA: + case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA: + case LIBXSMM_DNN_CHANNEL_BETA: + case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA: + case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA: + case LIBXSMM_DNN_CHANNEL_GAMMA: + case LIBXSMM_DNN_CHANNEL_EXPECTVAL: + case LIBXSMM_DNN_CHANNEL_RCPSTDDEV: + case LIBXSMM_DNN_CHANNEL_VARIANCE: + case LIBXSMM_DNN_CHANNEL_SCALAR: { + switch (out_format) { + case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: { + if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) { + switch (tensor->layout->datatype) { + case LIBXSMM_DNN_DATATYPE_F32: { + typedef float element_type; +#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c" + } break; + case LIBXSMM_DNN_DATATYPE_BF16: { + typedef libxsmm_bfloat16 element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I16: { + typedef short element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + case LIBXSMM_DNN_DATATYPE_I8: { + typedef char element_type; +#define LIBXSMM_DNN_COPY_LOW_PRECISION +#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c" +#undef LIBXSMM_DNN_COPY_LOW_PRECISION + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE; + } + } + } else { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT; + } + } break; + default: { + status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT; + } + } + } break; + default: { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + } + } + else { + status = LIBXSMM_DNN_ERR_INVALID_TENSOR; + } + + return status; +} + diff --git a/third_party/libxsmm/src/libxsmm_ext.c b/third_party/libxsmm/src/libxsmm_ext.c new file mode 100644 index 0000000000000000000000000000000000000000..42bc227a34c8ee047351879b47e0738ba7666ce0 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_ext.c @@ -0,0 +1,267 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_ext.h" +#include "libxsmm_gemm.h" +#include + + +#if defined(LIBXSMM_BUILD) +#if defined(LIBXSMM_BUILD_EXT) && !defined(__STATIC) + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void LIBXSMM_FSYMBOL(dgemm_batch)(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], + const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + if (LIBXSMM_FSYMBOL(__real_dgemm_batch) != libxsmm_original_dgemm_batch_function) { + LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + else { + libxsmm_blas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void LIBXSMM_FSYMBOL(sgemm_batch)(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], + const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + if (LIBXSMM_FSYMBOL(__real_sgemm_batch) != libxsmm_original_sgemm_batch_function) { + LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + else { + libxsmm_blas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void LIBXSMM_FSYMBOL(dgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm) +{ + if (LIBXSMM_FSYMBOL(__real_dgemm) != libxsmm_original_dgemm_function) { + LIBXSMM_FSYMBOL(__wrap_dgemm)(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } + else { + libxsmm_blas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void LIBXSMM_FSYMBOL(sgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm) +{ + if (LIBXSMM_FSYMBOL(__real_sgemm) != libxsmm_original_sgemm_function) { + LIBXSMM_FSYMBOL(__wrap_sgemm)(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } + else { + libxsmm_blas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void LIBXSMM_FSYMBOL(dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx, + const double* beta, double* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv) +{ + if (LIBXSMM_FSYMBOL(__real_dgemv) != libxsmm_original_dgemv_function) { + LIBXSMM_FSYMBOL(__wrap_dgemv)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } + else { + libxsmm_blas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void LIBXSMM_FSYMBOL(sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx, + const float* beta, float* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv) +{ + if (LIBXSMM_FSYMBOL(__real_sgemv) != libxsmm_original_sgemv_function) { + LIBXSMM_FSYMBOL(__wrap_sgemv)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } + else { + libxsmm_blas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void dgemm_batch(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], + const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK +void sgemm_batch(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], + const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + +#elif (0 != LIBXSMM_NO_BLAS) /* no-BLAS library */ + +LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state0[16]); +LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state1[16]); +LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state2[16]); +LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state3[16]); + +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE void internal_noblas_sink(LIBXSMM_VARIADIC); +LIBXSMM_API_INTERN void internal_noblas_sink(LIBXSMM_VARIADIC) +{ + /* does nothing else but sinking given arguments */ +} + +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE libxsmm_sink_function internal_noblas_error(const char* /*symbol*/); +LIBXSMM_API_INTERN libxsmm_sink_function internal_noblas_error(const char* symbol) +{ + static int internal_noblas_nerror = 0; + LIBXSMM_BLAS_ERROR(symbol, &internal_noblas_nerror); + return internal_noblas_sink; +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ +void LIBXSMM_FSYMBOL(dgemm_batch)(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], + const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + internal_noblas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ +void LIBXSMM_FSYMBOL(sgemm_batch)(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], + const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + internal_noblas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ +void LIBXSMM_FSYMBOL(dgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm) +{ + internal_noblas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ +void LIBXSMM_FSYMBOL(sgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm) +{ + internal_noblas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ +void LIBXSMM_FSYMBOL(dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx, + const double* beta, double* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv) +{ + internal_noblas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ +void LIBXSMM_FSYMBOL(sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx, + const float* beta, float* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv) +{ + internal_noblas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE +void dgemm_batch(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], + const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + + +LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE +void sgemm_batch(const char transa_array[], const char transb_array[], + const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], + const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], + const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch) +{ + LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + +#endif +#endif /*defined(LIBXSMM_BUILD)*/ + diff --git a/third_party/libxsmm/src/libxsmm_ext.h b/third_party/libxsmm/src/libxsmm_ext.h new file mode 100644 index 0000000000000000000000000000000000000000..1f6828891f927f430dd36e582d303f8db339b06d --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_ext.h @@ -0,0 +1,46 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_EXT_H +#define LIBXSMM_EXT_H + +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if defined(_OPENMP) +# if !defined(__INTEL_COMPILER) +# if defined(__clang__) +# pragma clang diagnostic push +# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) +# pragma GCC diagnostic push +# endif +# if defined(__clang__) +# pragma clang diagnostic ignored "-Wpedantic" +# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) +# pragma GCC diagnostic ignored "-Wpedantic" +# endif +# endif +# include +# if defined(LIBXSMM_TRACE_CALLERID_GCCBUILTIN) && !defined(__INTEL_COMPILER) +# if defined(__clang__) +# pragma clang diagnostic pop +# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) +# pragma GCC diagnostic pop +# endif +# endif +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#endif /*LIBXSMM_EXT_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_ext_gemm.c b/third_party/libxsmm/src/libxsmm_ext_gemm.c new file mode 100644 index 0000000000000000000000000000000000000000..9a17e35cec4aef370f38a4530d3eebfdd3f6f66c --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_ext_gemm.c @@ -0,0 +1,1268 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_gemm.h" +#include "libxsmm_ext.h" + +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) +# include "libxsmm_trace.h" +#endif + +#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) && 0 +# define LIBXSMM_EXT_GEMM_PARGROUPS_INFO +#endif + +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) +# if !defined(LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH) +# define LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO) +# endif +# if !defined(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) +# define LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH 8/*POT*/ +# endif +LIBXSMM_APIVAR_DEFINE(libxsmm_gemm_descriptor internal_ext_gemm_batchdesc[LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH]); +LIBXSMM_APIVAR_DEFINE(unsigned int internal_ext_gemm_batchdepth); +LIBXSMM_APIVAR_DEFINE(unsigned int internal_ext_gemm_batchsize); +#endif + + +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) +LIBXSMM_API_INLINE int internal_mmbatch_sortrev(const void* stat_a, const void* stat_b) +{ + const libxsmm_mmbatch_item *const a = (const libxsmm_mmbatch_item*)stat_a; + const libxsmm_mmbatch_item *const b = (const libxsmm_mmbatch_item*)stat_b; + LIBXSMM_ASSERT(NULL != stat_a && NULL != stat_b); + return a->stat.count < b->stat.count ? 1 : (b->stat.count < a->stat.count ? -1 : 0); +} +#endif /*defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)*/ + + +LIBXSMM_API_INLINE int internal_mmbatch_flush(const libxsmm_gemm_descriptor* batchdesc, + libxsmm_blasint batchsize, libxsmm_mmbatch_item* batcharray) +{ + int result = EXIT_SUCCESS; +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + if (0 != batchsize) { /* recorded/lazy multiplications */ + const libxsmm_blasint itemsize = sizeof(libxsmm_mmbatch_item); + LIBXSMM_ASSERT(NULL != batchdesc && 0 < batchsize); + if (0 == (LIBXSMM_MMBATCH_FLAG_STATISTIC & batchdesc->flags)) { /* process batch */ + const libxsmm_xmmfunction kernel = libxsmm_xmmdispatch(batchdesc); + if (NULL != kernel.xmm) { + const unsigned char itypesize = libxsmm_typesize((libxsmm_datatype)LIBXSMM_GETENUM_INP(batchdesc->datatype)); + const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)LIBXSMM_GETENUM_OUT(batchdesc->datatype)); +#if defined(_OPENMP) + if (0 == (LIBXSMM_MMBATCH_FLAG_SEQUENTIAL & batchdesc->flags)) { /* parallelized */ + const int nchunks = (int)LIBXSMM_UPDIV(batchsize, libxsmm_gemm_taskgrain); +# if defined(LIBXSMM_EXT_TASKS) + if (0 == omp_get_active_level()) { + const int max_nthreads = omp_get_max_threads(); + const int nthreads = LIBXSMM_MIN(max_nthreads, nchunks); + if (0 == libxsmm_gemm_tasks) +# else + if (0 == omp_in_parallel()) { + const int max_nthreads = omp_get_max_threads(); + const int nthreads = LIBXSMM_MIN(max_nthreads, nchunks); +# endif + { /* classic internal parallelization */ +# pragma omp parallel num_threads(nthreads) + /*check*/libxsmm_mmbatch_kernel( + kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize, + &batcharray->value.a, &batcharray->value.b, &batcharray->value.c, + 0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize, + omp_get_thread_num(), nthreads, itypesize, otypesize, batchdesc->flags); + } +# if defined(LIBXSMM_EXT_TASKS) + else { /* internal parallelization with tasks */ +# pragma omp parallel num_threads(nthreads) + { /* first thread discovering work will launch all tasks */ +# pragma omp single nowait /* anyone is good */ + { int tid; for (tid = 0; tid < nchunks/*ntasks*/; ++tid) { +# pragma omp task untied + /*check*/libxsmm_mmbatch_kernel( + kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize, + &batcharray->value.a, &batcharray->value.b, &batcharray->value.c, + 0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize, + tid, nchunks/*ntasks*/, itypesize, otypesize, batchdesc->flags); + } + } + } /* implicit synchronization (barrier) */ + } +# endif + } + else { /* assume external parallelization */ + int tid; for (tid = 0; tid < nchunks/*ntasks*/; ++tid) { +# if defined(LIBXSMM_EXT_TASKS) +# pragma omp task untied +#endif + /*check*/libxsmm_mmbatch_kernel( + kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize, + &batcharray->value.a, &batcharray->value.b, &batcharray->value.c, + 0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize, + tid, nchunks/*ntasks*/, itypesize, otypesize, batchdesc->flags); + } +# if defined(LIBXSMM_EXT_TASKS) + if (0 == libxsmm_nosync) { /* allow to omit synchronization */ +# pragma omp taskwait + } +# endif + } + } + else +#endif + { /* sequential */ + result = libxsmm_mmbatch_kernel( + kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize, + &batcharray->value.a, &batcharray->value.b, &batcharray->value.c, batchsize, + 0/*tid*/, 1/*nthreads*/, itypesize, otypesize, batchdesc->flags); + } + } + else { /* no fallback */ + /* several reasons to arrive here: try-lock, unsuitable SMM, etc. */ + result = EXIT_FAILURE; + } + memset(batcharray, 0, (size_t)batchsize * (size_t)itemsize); /* clear */ + } + else { /* print statistic */ + const libxsmm_blasint limit = (LIBXSMM_GEMM_MMBATCH_VERBOSITY < libxsmm_verbosity ? batchsize/*unlimited*/ : 7/*limited*/); + unsigned int threshold, batchcount; + libxsmm_blasint count = 0, i; + LIBXSMM_ASSERT(NULL != batcharray); + qsort(batcharray, (size_t)batchsize, (size_t)itemsize, internal_mmbatch_sortrev); + batchcount = batcharray[0].stat.count; + threshold = ((LIBXSMM_GEMM_MMBATCH_VERBOSITY < libxsmm_verbosity || 3 >= batchsize) ? 0 : (batchcount / 2)); + for (i = 1; i < batchsize; ++i) batchcount += batcharray[i].stat.count; + LIBXSMM_STDIO_ACQUIRE(); + for (i = 0; i < batchsize; ++i) { + const libxsmm_gemm_descriptor descriptor = batcharray[i].stat.desc; + const libxsmm_blasint lda = descriptor.lda, ldb = descriptor.ldb, ldc = descriptor.ldc; + const libxsmm_blasint m = descriptor.m, n = descriptor.n, k = descriptor.k; + const char *const symbol = batcharray[i].stat.symbol; + const unsigned int ci = batcharray[i].stat.count; + LIBXSMM_MEMZERO127(batcharray + i); /* clear */ + if (threshold < ci && count < limit /* limit printed statistic */ + && 0 < m && 0 < n && 0 < k) + { + const unsigned int ciperc = (unsigned int)(100.0 * ci / batchcount + 0.5); + if (0 != ciperc) { + LIBXSMM_ASSERT(0 != ci); + if (0 == count) { + fprintf(stderr, "\nLIBXSMM STATISTIC: %u multiplication%c\n", batchcount, 1 < batchcount ? 's' : ' '); + } + LIBXSMM_GEMM_PRINT2(stderr, + LIBXSMM_GETENUM_INP(descriptor.datatype), LIBXSMM_GETENUM_OUT(descriptor.datatype), descriptor.flags, m, n, k, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & descriptor.flags) ? 0 : */1, NULL/*a*/, lda, NULL/*b*/, ldb, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & descriptor.flags) ? 0 : 1, NULL/*c*/, ldc); + if (NULL != symbol && 0 != *symbol) { + fprintf(stderr, ": %u%% [%s]\n", ciperc, symbol); + } + else { + fprintf(stderr, ": %u%%\n", ciperc); + } + ++count; + } + else break; + } + } + LIBXSMM_STDIO_RELEASE(); + } + } +#else + LIBXSMM_UNUSED(batchdesc); LIBXSMM_UNUSED(batchsize); LIBXSMM_UNUSED(batcharray); +#endif + return result; +} + + +#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) + +#if defined(LIBXSMM_BLAS_WRAP_DYNAMIC) +LIBXSMM_API libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void) +{ +# if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, double, gemm_batch, libxsmm_original_dgemm_batch_function, libxsmm_original_dgemm_batch/*self*/); + /*LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_batch_function);*/ +# else + LIBXSMM_BLAS_WRAPPER(0, double, gemm_batch, libxsmm_original_dgemm_batch_function, libxsmm_original_dgemm_batch/*self*/); +# endif + return libxsmm_original_dgemm_batch_function; +} + +LIBXSMM_API libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void) +{ +# if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, float, gemm_batch, libxsmm_original_sgemm_batch_function, libxsmm_original_sgemm_batch/*self*/); + /*LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_batch_function);*/ +# else + LIBXSMM_BLAS_WRAPPER(0, float, gemm_batch, libxsmm_original_sgemm_batch_function, libxsmm_original_sgemm_batch/*self*/); +# endif + return libxsmm_original_sgemm_batch_function; +} + +LIBXSMM_API libxsmm_dgemm_function libxsmm_original_dgemm(void) +{ +# if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, double, gemm, libxsmm_original_dgemm_function, libxsmm_original_dgemm/*self*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_function); +# else + LIBXSMM_BLAS_WRAPPER(0, double, gemm, libxsmm_original_dgemm_function, libxsmm_original_dgemm/*self*/); +# endif + return libxsmm_original_dgemm_function; +} + +LIBXSMM_API libxsmm_sgemm_function libxsmm_original_sgemm(void) +{ +# if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, float, gemm, libxsmm_original_sgemm_function, libxsmm_original_sgemm/*self*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_function); +# else + LIBXSMM_BLAS_WRAPPER(0, float, gemm, libxsmm_original_sgemm_function, libxsmm_original_sgemm/*self*/); +# endif + return libxsmm_original_sgemm_function; +} + +LIBXSMM_API libxsmm_dgemv_function libxsmm_original_dgemv(void) +{ +# if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, double, gemv, libxsmm_original_dgemv_function, libxsmm_original_dgemv/*self*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_dgemv_function); +# else + LIBXSMM_BLAS_WRAPPER(0, double, gemv, libxsmm_original_dgemv_function, libxsmm_original_dgemv/*self*/); +# endif + return libxsmm_original_dgemv_function; +} + +LIBXSMM_API libxsmm_sgemv_function libxsmm_original_sgemv(void) +{ +# if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, float, gemv, libxsmm_original_sgemv_function, libxsmm_original_sgemv/*self*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_sgemv_function); +# else + LIBXSMM_BLAS_WRAPPER(0, float, gemv, libxsmm_original_sgemv_function, libxsmm_original_sgemv/*self*/); +# endif + return libxsmm_original_sgemv_function; +} +#endif /*defined(LIBXSMM_BLAS_WRAP_DYNAMIC)*/ + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemm_batch)( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + LIBXSMM_ASSERT(NULL != lda_array && NULL != ldb_array && NULL != ldc_array && NULL != m_array && NULL != n_array && NULL != k_array); + LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != alpha_array && NULL != beta_array); + LIBXSMM_ASSERT(NULL != group_count && NULL != group_size); + LIBXSMM_INIT + if (0 != libxsmm_gemm_wrap) { + if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */ + libxsmm_dgemm_batch(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + else { /* parallelized */ + libxsmm_dgemm_batch_omp(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + } + else { + LIBXSMM_GEMM_BATCH_SYMBOL(double)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemm_batch)( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + LIBXSMM_ASSERT(NULL != lda_array && NULL != ldb_array && NULL != ldc_array && NULL != m_array && NULL != n_array && NULL != k_array); + LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != alpha_array && NULL != beta_array); + LIBXSMM_ASSERT(NULL != group_count && NULL != group_size); + LIBXSMM_INIT + if (0 != libxsmm_gemm_wrap) { + if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */ + libxsmm_sgemm_batch(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + else { /* parallelized */ + libxsmm_sgemm_batch_omp(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + } + else { + LIBXSMM_GEMM_BATCH_SYMBOL(float)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemm)( + const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_ASSERT(NULL != lda && NULL != ldb && NULL != ldc && NULL != m && NULL != n && NULL != k); + LIBXSMM_ASSERT(NULL != transa && NULL != transb && NULL != alpha && NULL != beta); + { +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + unsigned int i = 0; /* no flush */ + int flags = -1; +# if !defined(NDEBUG) + static int error_once = 0; + int result = EXIT_SUCCESS; +# endif + LIBXSMM_INIT + if (0 != libxsmm_gemm_wrap && (NULL == libxsmm_mmbatch_array + || LIBXSMM_GEMM_PRECISION_F64 != libxsmm_mmbatch_desc.datatype + || ((unsigned int)*lda) != libxsmm_mmbatch_desc.lda + || ((unsigned int)*ldb) != libxsmm_mmbatch_desc.ldb + || ((unsigned int)*ldc) != libxsmm_mmbatch_desc.ldc + || ((unsigned int)*m) != libxsmm_mmbatch_desc.m + || ((unsigned int)*n) != libxsmm_mmbatch_desc.n + || ((unsigned int)*k) != libxsmm_mmbatch_desc.k + || (flags = LIBXSMM_GEMM_FLAGS(*transa, *transb)) != (int)(LIBXSMM_GEMM_FLAG_TRANS_AB & libxsmm_mmbatch_desc.flags) + || LIBXSMM_NEQ(/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, *alpha) + || LIBXSMM_NEQ(0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, *beta))) +#endif + { +#if defined(_DEBUG) + const char *const env_check = getenv("LIBXSMM_GEMM_CHECK"); + const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check)); + void* d = NULL; + if (LIBXSMM_NEQ(0, check)) { + const size_t size = (size_t)(*ldc) * (size_t)(*n) * sizeof(double); + d = libxsmm_scratch_malloc(size, 0/*auto*/, LIBXSMM_MALLOC_INTERNAL_CALLER); + if (NULL != d && LIBXSMM_NEQ(0, *beta)) memcpy(d, c, size); /* copy destination */ + } +#endif + if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */ + libxsmm_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } + else { /* parallelized */ + libxsmm_dgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +#if defined(_DEBUG) + if (NULL != d) { + libxsmm_matdiff_info diff; + libxsmm_blas_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, d, ldc); + if (EXIT_SUCCESS == libxsmm_matdiff(&diff, LIBXSMM_DATATYPE_F64, *m, *n, d, c, ldc, ldc) + && check < 100.0 * diff.normf_rel) + { + LIBXSMM_STDIO_ACQUIRE(); + fprintf(stderr, "LIBXSMM: "); + libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION_F64, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + fprintf(stderr, " => %f%% ERROR\n", 100.0 * diff.normf_rel); + LIBXSMM_STDIO_RELEASE(); + } + libxsmm_free(d); + } +#endif +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + if (0 != (LIBXSMM_MMBATCH_FLAG_STATISTIC & libxsmm_mmbatch_desc.flags)) { + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const descriptor = libxsmm_dgemm_descriptor_init(&blob, + *m, *n, *k, *lda, *ldb, *ldc, *alpha, *beta, LIBXSMM_GEMM_FLAGS(*transa, *transb), + LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH); + + LIBXSMM_ASSERT(0 != libxsmm_mmbatch_size); + if (NULL != descriptor) { + const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size); + const unsigned int batchsize = LIBXSMM_ATOMIC_LOAD(&internal_ext_gemm_batchsize, LIBXSMM_ATOMIC_RELAXED); + const unsigned int max_size = (0 != batchsize ? (((batchsize - 1) % max_batchsize) + 1) : 0); + libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array; + libxsmm_mmbatch_item* batcharray_cur = batcharray; + unsigned int size = max_size; + if (libxsmm_mmbatch_size < max_size) { + size = max_size - libxsmm_mmbatch_size; + batcharray_cur += libxsmm_mmbatch_size; + } + i = libxsmm_diff_n(descriptor, batcharray_cur, sizeof(libxsmm_gemm_descriptor), + sizeof(libxsmm_mmbatch_item)/*stride*/, 0/*hint*/, size); + + if (i < size) { /* update existing entry */ + LIBXSMM_ATOMIC_ADD_FETCH(&batcharray_cur[i].stat.count, 1, LIBXSMM_ATOMIC_RELAXED); + } + else { /* new entry needed */ + const int all = -1, shift = 0; + void* extra = 0; + i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1; + batcharray[i-1].stat.desc = *descriptor; + batcharray[i-1].stat.count = 1; + batcharray[i-1].stat.symbol = libxsmm_trace_info(NULL/*depth*/, NULL/*tid*/, &all, LIBXSMM_FUNCNAME, &shift, &all); + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(libxsmm_mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra)) { + *(libxsmm_mmbatch_flush_function*)extra = libxsmm_mmbatch_end; + } +# if !defined(NDEBUG) + else { + result = EXIT_FAILURE; + } +# endif + } + } + } +#endif + } +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + else { + libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array; + const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size); + i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1; + batcharray[i-1].value.a = a; + batcharray[i-1].value.b = b; + batcharray[i-1].value.c = c; + LIBXSMM_ASSERT(0 <= flags); + } + if (libxsmm_mmbatch_size == (i - 1)) { /* condition ensure to flush once (first discovery) */ +# if !defined(NDEBUG) + result = +# endif + internal_mmbatch_flush(&libxsmm_mmbatch_desc, libxsmm_mmbatch_size, (libxsmm_mmbatch_item*)libxsmm_mmbatch_array); + } +# if !defined(NDEBUG) /* library code is expected to be mute */ + if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity && + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: DGEMM batch recording failed!\n"); + } +# endif +#endif + } +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemm)( + const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_ASSERT(NULL != lda && NULL != ldb && NULL != ldc && NULL != m && NULL != n && NULL != k); + LIBXSMM_ASSERT(NULL != transa && NULL != transb && NULL != alpha && NULL != beta); + { +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + unsigned int i = 0; /* no flush */ + int flags = -1; +# if !defined(NDEBUG) + static int error_once = 0; + int result = EXIT_SUCCESS; +# endif + LIBXSMM_INIT + if (0 != libxsmm_gemm_wrap && (NULL == libxsmm_mmbatch_array + || LIBXSMM_GEMM_PRECISION_F32 != libxsmm_mmbatch_desc.datatype + || ((unsigned int)*lda) != libxsmm_mmbatch_desc.lda + || ((unsigned int)*ldb) != libxsmm_mmbatch_desc.ldb + || ((unsigned int)*ldc) != libxsmm_mmbatch_desc.ldc + || ((unsigned int)*m) != libxsmm_mmbatch_desc.m + || ((unsigned int)*n) != libxsmm_mmbatch_desc.n + || ((unsigned int)*k) != libxsmm_mmbatch_desc.k + || (flags = LIBXSMM_GEMM_FLAGS(*transa, *transb)) != (int)(LIBXSMM_GEMM_FLAG_TRANS_AB & libxsmm_mmbatch_desc.flags) + || LIBXSMM_NEQ(/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, *alpha) + || LIBXSMM_NEQ(0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, *beta))) +#endif + { +#if defined(_DEBUG) + const char *const env_check = getenv("LIBXSMM_GEMM_CHECK"); + const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check)); + void* d = NULL; + if (LIBXSMM_NEQ(0, check)) { + const size_t size = (size_t)(*ldc) * (size_t)(*n) * sizeof(float); + d = libxsmm_scratch_malloc(size, 0/*auto*/, LIBXSMM_MALLOC_INTERNAL_CALLER); + if (NULL != d && LIBXSMM_NEQ(0, *beta)) memcpy(d, c, size); /* copy destination */ + } +#endif + if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */ + libxsmm_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } + else { /* parallelized */ + libxsmm_sgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +#if defined(_DEBUG) + if (NULL != d) { + libxsmm_matdiff_info diff; + libxsmm_blas_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, d, ldc); + if (EXIT_SUCCESS == libxsmm_matdiff(&diff, LIBXSMM_DATATYPE_F32, *m, *n, d, c, ldc, ldc) + && check < 100.0 * diff.normf_rel) + { + LIBXSMM_STDIO_ACQUIRE(); + fprintf(stderr, "LIBXSMM: "); + libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION_F32, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + fprintf(stderr, " => %f%% ERROR\n", 100.0 * diff.normf_rel); + LIBXSMM_STDIO_RELEASE(); + } + libxsmm_free(d); + } +#endif +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + if (0 != (LIBXSMM_MMBATCH_FLAG_STATISTIC & libxsmm_mmbatch_desc.flags)) { + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const descriptor = libxsmm_sgemm_descriptor_init(&blob, + *m, *n, *k, *lda, *ldb, *ldc, *alpha, *beta, LIBXSMM_GEMM_FLAGS(*transa, *transb), + LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH); + + LIBXSMM_ASSERT(0 != libxsmm_mmbatch_size); + if (NULL != descriptor) { + const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size); + const unsigned int batchsize = LIBXSMM_ATOMIC_LOAD(&internal_ext_gemm_batchsize, LIBXSMM_ATOMIC_RELAXED); + const unsigned int max_size = (0 != batchsize ? (((batchsize - 1) % max_batchsize) + 1) : 0); + libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array; + libxsmm_mmbatch_item* batcharray_cur = batcharray; + unsigned int size = max_size; + if (libxsmm_mmbatch_size < max_size) { + size = max_size - libxsmm_mmbatch_size; + batcharray_cur += libxsmm_mmbatch_size; + } + i = libxsmm_diff_n(descriptor, batcharray_cur, sizeof(libxsmm_gemm_descriptor), + sizeof(libxsmm_mmbatch_item)/*stride*/, 0/*hint*/, size); + + if (i < size) { /* update existing entry */ + LIBXSMM_ATOMIC_ADD_FETCH(&batcharray_cur[i].stat.count, 1, LIBXSMM_ATOMIC_RELAXED); + } + else { /* new entry needed */ + const int all = -1, shift = 0; + void* extra = 0; + i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1; + batcharray[i-1].stat.desc = *descriptor; + batcharray[i-1].stat.count = 1; + batcharray[i-1].stat.symbol = libxsmm_trace_info(NULL/*depth*/, NULL/*tid*/, &all, LIBXSMM_FUNCNAME, &shift, &all); + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(libxsmm_mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra)) { + *(libxsmm_mmbatch_flush_function*)extra = libxsmm_mmbatch_end; + } +# if !defined(NDEBUG) + else { + result = EXIT_FAILURE; + } +# endif + } + } + } +#endif + } +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) + else { + libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array; + const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size); + i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1; + batcharray[i-1].value.a = a; + batcharray[i-1].value.b = b; + batcharray[i-1].value.c = c; + LIBXSMM_ASSERT(0 <= flags); + } + if (libxsmm_mmbatch_size == (i - 1)) { /* condition ensure to flush once (first discovery) */ +# if !defined(NDEBUG) + result = +# endif + internal_mmbatch_flush(&libxsmm_mmbatch_desc, libxsmm_mmbatch_size, (libxsmm_mmbatch_item*)libxsmm_mmbatch_array); + } +# if !defined(NDEBUG) /* library code is expected to be mute */ + if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity && + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: SGEMM batch recording failed!\n"); + } +# endif +#endif + } +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx, + const double* beta, double* y, const libxsmm_blasint* incy) +{ + LIBXSMM_ASSERT(NULL != trans && NULL != m && NULL != n && NULL != lda && NULL != incx && NULL != incy && NULL != alpha && NULL != beta); + LIBXSMM_INIT + if ((2 < libxsmm_gemm_wrap || 2 > libxsmm_gemm_wrap) && 1 == *incx && 1 == *incy && LIBXSMM_SMM(*m, 1, *n, 2/*RFO*/, sizeof(double))) { + if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */ + const int flags = LIBXSMM_GEMM_FLAGS(*trans, 'N'); + const libxsmm_dmmfunction xgemv = libxsmm_dmmdispatch(*m, 1, *n, lda, n/*ldb*/, m/*ldc*/, alpha, beta, &flags, NULL); + if (NULL != xgemv) { + LIBXSMM_MMCALL_LDX(xgemv, a, x, y, *m, 1, *n, *lda, *n/*ldb*/, *m/*ldc*/); + } + else { + LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } + } + else { /* TODO: parallelized */ + LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } + } + else { + LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx, + const float* beta, float* y, const libxsmm_blasint* incy) +{ + LIBXSMM_ASSERT(NULL != trans && NULL != m && NULL != n && NULL != lda && NULL != incx && NULL != incy && NULL != alpha && NULL != beta); + LIBXSMM_INIT + if ((2 < libxsmm_gemm_wrap || 2 > libxsmm_gemm_wrap) && 1 == *incx && 1 == *incy && LIBXSMM_SMM(*m, 1, *n, 2/*RFO*/, sizeof(float))) { + if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */ + const int flags = LIBXSMM_GEMM_FLAGS(*trans, 'N'); + const libxsmm_smmfunction xgemv = libxsmm_smmdispatch(*m, 1, *n, lda, n/*ldb*/, m/*ldc*/, alpha, beta, &flags, NULL); + if (NULL != xgemv) { + LIBXSMM_MMCALL_LDX(xgemv, a, x, y, *m, 1, *n, *lda, *n/*ldb*/, *m/*ldc*/); + } + else { + LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } + } + else { /* TODO: parallelized */ + LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } + } + else { + LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + } +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void __wrap_dgemm_batch( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + + +LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void __wrap_sgemm_batch( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + +#endif /*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT)*/ + + +LIBXSMM_APIEXT void libxsmm_xgemm_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc) +{ + libxsmm_gemm_blob blob; +#if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */ + const int outerpar = omp_get_active_level(), nthreads = (0 == outerpar ? omp_get_max_threads() : omp_get_num_threads()); +#elif defined(_OPENMP) + const int outerpar = omp_in_parallel(), nthreads = (0 == outerpar ? omp_get_max_threads() : 1); +#else + const int nthreads = 1; +#endif + const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob, iprec, oprec, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta, LIBXSMM_GEMM_HANDLE_FLAG_AUTO, nthreads); + const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle); + void* scratch = NULL; + if (NULL != handle && (0 == scratch_size || + NULL != (scratch = libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER)))) + { +#if defined(_OPENMP) + if (0 == outerpar) { /* enable internal parallelization */ +# if defined(LIBXSMM_EXT_TASKS) + if (0 == libxsmm_gemm_tasks) +# endif + { +# pragma omp parallel num_threads(nthreads) + libxsmm_gemm_task(handle, scratch, a, b, c, omp_get_thread_num(), nthreads); + } +# if defined(LIBXSMM_EXT_TASKS) + else { /* tasks requested */ + const int ntasks = nthreads; /* TODO: apply grain-size */ +# pragma omp parallel num_threads(nthreads) + { /* first thread discovering work will launch all tasks */ +# pragma omp single nowait /* anyone is good */ + { int tid; for (tid = 0; tid < ntasks; ++tid) { +# pragma omp task untied + libxsmm_gemm_task(handle, scratch, a, b, c, tid, ntasks); + } + } + } /* implicit synchronization (barrier) */ + } +# endif + } + else { /* assume external parallelization */ +# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */ + const int ntasks = nthreads; /* TODO: apply grain-size */ + int tid; for (tid = 0; tid < ntasks; ++tid) { +# pragma omp task untied + libxsmm_gemm_task(handle, scratch, a, b, c, tid, ntasks); + } + if (0 == libxsmm_nosync) { /* allow to omit synchronization */ +# pragma omp taskwait + } +# else + libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*nthreads*/); +# endif + } + if (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) { /* library code is expected to be mute */ + const unsigned int ntasks = handle->mt * handle->nt * handle->kt; + const double imbalance = 100.0 * LIBXSMM_DELTA((unsigned int)nthreads, ntasks) / nthreads; + static double max_imbalance = 50.0; + if (max_imbalance < imbalance) { + fprintf(stderr, "LIBXSMM WARNING: XGEMM %.0f%% imbalance (%u of %i workers utilized)!\n", + imbalance, ntasks, nthreads); + max_imbalance = imbalance; + } + } +#else + libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*nthreads*/); +#endif /*defined(_OPENMP)*/ + libxsmm_free(scratch); + } + else { /* fallback or error */ + static int error_once = 0; + if (NULL == handle) { /* fallback */ + if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: XGEMM fallback code path triggered!\n"); + } + } + else if (0 != libxsmm_verbosity && /* library code is expected to be mute */ + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate GEMM-scratch memory!\n"); + } + libxsmm_blas_xgemm(iprec, oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +} + + +LIBXSMM_API_INLINE void internal_gemm_batch_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char transa[], const char transb[], const libxsmm_blasint m[], const libxsmm_blasint n[], const libxsmm_blasint k[], + const void* alpha, const void* a[], const libxsmm_blasint lda[], const void* b[], const libxsmm_blasint ldb[], + const void* beta, void* c[], const libxsmm_blasint ldc[], libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + const libxsmm_blasint batchsize[], libxsmm_blasint group_count) +{ + static int error_once = 0; + LIBXSMM_INIT + if ( /* check for sensible arguments */ +#if defined(LIBXSMM_BATCH_CHECK) + NULL != a && NULL != b && NULL != c && (1 == group_count || -1 == group_count || + (0 == index_stride && (NULL == stride_a || 0 != *stride_a) && (NULL == stride_b || 0 != *stride_b) && (NULL == stride_c || 0 != *stride_c))) && +#endif + 0 != group_count) + { + int result = EXIT_SUCCESS; + const int max_npargroups = (int)(0 < libxsmm_gemm_npargroups + ? LIBXSMM_MIN(libxsmm_gemm_npargroups, LIBXSMM_GEMM_NPARGROUPS) : LIBXSMM_GEMM_NPARGROUPS); + const libxsmm_gemm_prefetch_type prefetch = libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + const size_t sa = (NULL != stride_a ? (size_t)(*stride_a) : sizeof(void*)); + const size_t sb = (NULL != stride_b ? (size_t)(*stride_b) : sizeof(void*)); + const size_t sc = (NULL != stride_c ? (size_t)(*stride_c) : sizeof(void*)); + const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)oprec); + const int ngroups = (int)LIBXSMM_ABS(group_count); + int group = 0, group_next = LIBXSMM_GEMM_NPARGROUPS; + libxsmm_code_pointer kernel[LIBXSMM_GEMM_NPARGROUPS]; + libxsmm_blasint base[LIBXSMM_GEMM_NPARGROUPS], i; +#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + int kflags[LIBXSMM_GEMM_NPARGROUPS]; +#endif + int max_nthreads = 1; +#if defined(_OPENMP) +# if defined(LIBXSMM_EXT_TASKS) + const int outerpar = omp_get_active_level(); +# else + const int outerpar = omp_in_parallel(); +# endif + if (0 == outerpar) max_nthreads = omp_get_max_threads(); +#endif + for (i = 0; i < max_npargroups; ++i) { +#if !defined(NDEBUG) + kernel[i].ptr = NULL; +# if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + kflags[i] = 0; +# endif +#endif + base[i] = 0; + } + for (group = 0; group < ngroups; group = group_next, group_next += max_npargroups) { + const int npargroups = LIBXSMM_MIN(group_next, ngroups); + libxsmm_blasint size = 0; + int suitable = 0; + if (0 < group) { /* base is maintained even if par-group is not suitable */ + for (i = 0; i < npargroups; ++i) { + const libxsmm_blasint isize = batchsize[group+i-1], asize = LIBXSMM_ABS(isize); + base[i] += asize; + } + } + for (i = 0; i < npargroups; ++i) { + const libxsmm_blasint g = group + i, im = m[g], in = n[g], ik = k[g]; + suitable = LIBXSMM_SMM_AI(im, in, ik, 2/*RFO*/, otypesize); + if (0 != suitable) { + const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize); + const char *const ta = (NULL != transa ? (transa + g) : NULL); + const char *const tb = (NULL != transb ? (transb + g) : NULL); + const int flags = LIBXSMM_GEMM_PFLAGS(ta, tb, LIBXSMM_FLAGS); + const void **const galpha = &alpha, **const gbeta = β + libxsmm_descriptor_blob blob; + /* coverity[ptr_arith] */ + libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, im, in, ik, + NULL != lda ? lda[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & flags) ? im : ik), + NULL != ldb ? ldb[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & flags) ? ik : in), + NULL != ldc ? ldc[g] : im, NULL != alpha ? galpha[g] : NULL, NULL != beta ? gbeta[g] : NULL, + flags, prefetch); + if (NULL != desc) { + libxsmm_gemm_internal_set_batchflag(desc, c, index_stride, 0 < group_count ? isize : -asize, 1 != max_nthreads); + kernel[i].xgemm = libxsmm_xmmdispatch(desc); + } + else kernel[i].ptr = NULL; + if (NULL != kernel[i].ptr_const) { + if (size < asize) size = asize; +#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + LIBXSMM_ASSERT(NULL != desc); /* coverity[var_deref_op] */ + kflags[i] = desc->flags; +#endif + } + else { + suitable = 0; + break; + } + } + else break; + } + if (0 != suitable) { /* check if an SMM is suitable */ + const unsigned char itypesize = libxsmm_typesize((libxsmm_datatype)iprec); +#if defined(_OPENMP) + const int nchunks = (int)LIBXSMM_UPDIV(size, libxsmm_gemm_taskgrain); + const int ntasks = nchunks * npargroups, nthreads = LIBXSMM_MIN(max_nthreads, ntasks); + if (1 < nthreads) { + if (0 == outerpar) { /* enable internal parallelization */ +# if defined(LIBXSMM_EXT_TASKS) + if (0 == libxsmm_gemm_tasks) +# endif + { +# pragma omp parallel for num_threads(nthreads) private(i) + for (i = 0; i < ntasks; ++i) { + const libxsmm_blasint j = i * libxsmm_gemm_taskgrain, u = j / size, v = j - u * size, g = group + u; + const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize); + if (v < asize) { +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + libxsmm_mmkernel_info kernel_info; +#endif + /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride, stride_a, stride_b, stride_c, + (const char*)a + sa * base[u], (const char*)b + sb * base[u], (char*)c + sc * base[u], + 0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize, +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info) ? kernel_info.flags : 0); +#else + kflags[g]); +#endif + } + } + } +# if defined(LIBXSMM_EXT_TASKS) + else { /* tasks requested */ +# pragma omp parallel num_threads(nthreads) private(i) + { /* first thread discovering work will launch all tasks */ +# pragma omp single nowait /* anyone is good */ + for (i = 0; i < ntasks; ++i) { + const libxsmm_blasint j = i * libxsmm_gemm_taskgrain, u = j / size, v = j - u * size, g = group + u; + const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize); + if (v < asize) { +# pragma omp task + { +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + libxsmm_mmkernel_info kernel_info; +#endif + /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride, stride_a, stride_b, stride_c, + (const char*)a + sa * base[u], (const char*)b + sb * base[u], (char*)c + sc * base[u], + 0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize, +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info) ? kernel_info.flags : 0); +#else + kflags[g]); +#endif + } + } + } + } /* implicit synchronization (barrier) */ + } +# endif + } + else { /* assume external parallelization */ + for (i = 0; i < (libxsmm_blasint)ntasks; ++i) { + const libxsmm_blasint j = i * libxsmm_gemm_taskgrain, u = j / size, v = j - u * size, g = group + u; + const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize); + if (v < asize) { +# if defined(LIBXSMM_EXT_TASKS) /* OpenMP-tasks */ +# pragma omp task +#endif + { +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + libxsmm_mmkernel_info kernel_info; +#endif + /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride, stride_a, stride_b, stride_c, + (const char*)a + sa * base[u], (const char*)b + sb * base[u], (char*)c + sc * base[u], + 0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize, +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info) ? kernel_info.flags : 0); +#else + kflags[g]); +#endif + } + } + } +# if defined(LIBXSMM_EXT_TASKS) /* OpenMP-tasks */ + if (0 == libxsmm_nosync) { /* allow to omit synchronization */ +# pragma omp taskwait + } +# endif + } + } + else +#endif /*defined(_OPENMP)*/ + { /* sequential */ + for (i = 0; i < npargroups; ++i) { + const libxsmm_blasint g = group + i; +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + libxsmm_mmkernel_info kernel_info; +#endif + libxsmm_mmbatch_kernel(kernel[i].xgemm, index_base, index_stride, stride_a, stride_b, stride_c, + (const char*)a + sa * base[i], (const char*)b + sb * base[i], (char*)c + sc * base[i], batchsize[g], + 0/*tid*/, 1/*nthreads*/, itypesize, otypesize, +#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) + EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[i].xgemm, &kernel_info) ? kernel_info.flags : 0); +#else + kflags[i]); +#endif + } + } + } + else { /* trigger fallback */ + result = EXIT_FAILURE; + } + if (EXIT_SUCCESS != result) { + for (i = 0; i < npargroups; ++i) { + const libxsmm_blasint g = group + i; + const char *const ta = (NULL != transa ? (transa + g) : NULL); + const char *const tb = (NULL != transb ? (transb + g) : NULL); + const int flags = LIBXSMM_GEMM_PFLAGS(ta, tb, LIBXSMM_FLAGS); + const libxsmm_blasint im = m[g], in = n[g], ik = k[g]; + const libxsmm_blasint ilda = (NULL != lda ? lda[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & flags) ? im : ik)); + const libxsmm_blasint ildb = (NULL != ldb ? ldb[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & flags) ? ik : in)); + const libxsmm_blasint ildc = (NULL != ldc ? ldc[g] : im); + const void **const galpha = &alpha, **const gbeta = β + /* coverity[overrun-local] */ + const void *const ialpha = (NULL != alpha ? galpha[g] : NULL); + /* coverity[overrun-local] */ + const void *const ibeta = (NULL != beta ? gbeta[g] : NULL); + if (EXIT_SUCCESS == libxsmm_mmbatch_blas(iprec, oprec, ta, tb, im, in, ik, ialpha, + (const char*)a + sa * base[i], &ilda, (const char*)b + sb * base[i], &ildb, ibeta, (char*)c + sc * base[i], &ildc, + index_base, index_stride, stride_a, stride_b, stride_c, batchsize[g])) + { + if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) { + const size_t threshold = LIBXSMM_MNK_SIZE(im, in, im); + static size_t threshold_max = 0; + if (threshold_max < threshold) { + LIBXSMM_STDIO_ACQUIRE(); + fprintf(stderr, "LIBXSMM WARNING: "); + libxsmm_gemm_print2(stderr, iprec, oprec, ta, tb, &im, &in, &ik, + ialpha, NULL/*a*/, &ilda, NULL/*b*/, &ildb, ibeta, NULL/*c*/, &ildc); + fprintf(stderr, " => batched GEMM/omp was falling back to BLAS!\n"); + LIBXSMM_STDIO_RELEASE(); + threshold_max = threshold; + } + } + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: libxsmm_gemm_batch_omp failed!\n"); + } + return; /* exit routine */ + } + } + } + } + } +#if defined(LIBXSMM_BATCH_CHECK) + else if (0 != group_count && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: incorrect arguments (libxsmm_gemm_batch_omp)!\n"); + } +#endif +} + + +LIBXSMM_APIEXT void libxsmm_gemm_batch_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize) +{ + internal_gemm_batch_omp(iprec, oprec, transa, transb, &m, &n, &k, + alpha, (const void**)a, lda, (const void**)b, ldb, beta, (void**)c, ldc, index_base, index_stride, + stride_a, stride_b, stride_c, &batchsize, 1); +} + + +LIBXSMM_APIEXT void libxsmm_dgemm_batch_omp( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + if (NULL != group_count) { + const libxsmm_blasint ptrsize = sizeof(void*); + internal_gemm_batch_omp(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, transa_array, transb_array, m_array, n_array, k_array, + alpha_array, (const void**)a_array, lda_array, (const void**)b_array, ldb_array, beta_array, (void**)c_array, ldc_array, + 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, group_size, *group_count); + } +} + + +LIBXSMM_APIEXT void libxsmm_sgemm_batch_omp( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + if (NULL != group_count) { + const libxsmm_blasint ptrsize = sizeof(void*); + internal_gemm_batch_omp(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, transa_array, transb_array, m_array, n_array, k_array, + alpha_array, (const void**)a_array, lda_array, (const void**)b_array, ldb_array, beta_array, (void**)c_array, ldc_array, + 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, group_size, *group_count); + } +} + + +LIBXSMM_APIEXT void libxsmm_mmbatch_begin(libxsmm_gemm_precision precision, + const int* flags, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const void* alpha, const void* beta) +{ +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 26115) /* try-lock is treated incorrectly by static analysis */ +# endif + LIBXSMM_INIT + if (NULL != libxsmm_mmbatch_array /* batch-recording available, but not yet running */ + /* currently, batch recording is only enabled if all values are present (no complex filtering) */ + && NULL != flags && NULL != alpha && NULL != beta + && NULL != lda && NULL != ldb && NULL != ldc + && NULL != m && NULL != n && NULL != k + && LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_DEFAULT) == LIBXSMM_LOCK_TRYLOCK(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock)) + { + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const descriptor = libxsmm_gemm_descriptor_init(&blob, precision, + *m, *n, *k, *lda, *ldb, *ldc, alpha, beta, *flags, libxsmm_get_gemm_prefetch(LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH)); + static int error_once = 0; + int result = EXIT_SUCCESS; + + if (NULL != descriptor) { + const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size); + unsigned int i; +#if !defined(NDEBUG) + const unsigned int mmbatch_maxdepth = LIBXSMM_UP2POT(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH); + LIBXSMM_ASSERT((LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) == mmbatch_maxdepth/*is pot*/); +#endif + /* eventually overwrite the oldest entry */ + i = LIBXSMM_MOD2(internal_ext_gemm_batchdepth, LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH); + internal_ext_gemm_batchdesc[i] = libxsmm_mmbatch_desc; /* backup */ + ++internal_ext_gemm_batchdepth; + + /* ensure descriptor does not match any GEMM such that... */ + LIBXSMM_MEMZERO127(&libxsmm_mmbatch_desc); + /* ...the batch stops and completely flushes */ + if (0 != internal_ext_gemm_batchsize) { + result = internal_mmbatch_flush(internal_ext_gemm_batchdesc + i, + (((libxsmm_blasint)internal_ext_gemm_batchsize - 1) % max_batchsize) + 1, + (libxsmm_mmbatch_item*)libxsmm_mmbatch_array); + } + + if (EXIT_SUCCESS == result) { /* enable descriptor */ + internal_ext_gemm_batchsize = 0; /* reset */ + if (0 == (LIBXSMM_MMBATCH_FLAG_STATISTIC & *flags)) { + libxsmm_mmbatch_desc = *descriptor; + } + else { + libxsmm_mmbatch_desc.flags = LIBXSMM_MMBATCH_FLAG_STATISTIC; + } + } + } + else { + result = EXIT_FAILURE; + } + if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: GEMM batch enabling failed!\n"); + } + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock); + } +# if defined(_MSC_VER) +# pragma warning(pop) +# endif +#else + LIBXSMM_UNUSED(precision); LIBXSMM_UNUSED(flags); + LIBXSMM_UNUSED(m); LIBXSMM_UNUSED(n); LIBXSMM_UNUSED(k); + LIBXSMM_UNUSED(lda); LIBXSMM_UNUSED(ldb); LIBXSMM_UNUSED(ldc); + LIBXSMM_UNUSED(alpha); LIBXSMM_UNUSED(beta); +#endif +} + + +LIBXSMM_APIEXT void libxsmm_mmbatch_end(void) +{ +#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT) +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 26115) /* try-lock is treated incorrectly by static analysis */ +# endif + /*const*/ int trystate = LIBXSMM_LOCK_TRYLOCK(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock); + if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_DEFAULT) == trystate) { + const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size); + const libxsmm_gemm_descriptor flushdesc = libxsmm_mmbatch_desc; + static int error_once = 0; +#if !defined(NDEBUG) + const unsigned int mmbatch_maxdepth = LIBXSMM_UP2POT(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH); +#endif + /* ensure descriptor does not match any GEMM such that... */ + LIBXSMM_MEMZERO127(&libxsmm_mmbatch_desc); + /* ...the batch stops and completely flushes */ + if (EXIT_SUCCESS == internal_mmbatch_flush(&flushdesc, + 0 != internal_ext_gemm_batchsize ? (((internal_ext_gemm_batchsize - 1) % max_batchsize) + 1) : 0, + (libxsmm_mmbatch_item*)libxsmm_mmbatch_array)) + { + internal_ext_gemm_batchsize = 0; /* reset */ + --internal_ext_gemm_batchdepth; /* restore the previous descriptor */ + assert((LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) == mmbatch_maxdepth/*is pot*/); /* no LIBXSMM_ASSERT! */ + libxsmm_mmbatch_desc = internal_ext_gemm_batchdesc[LIBXSMM_MOD2(internal_ext_gemm_batchdepth, LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH)]; + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: GEMM batch processing failed!\n"); + } + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock); + } +# if defined(_MSC_VER) +# pragma warning(pop) +# endif +#endif +} + + +#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_xgemm_omp)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*, + const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const double*, const double*, const libxsmm_blasint*, const double*, const libxsmm_blasint*, + const double*, double*, const libxsmm_blasint*); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_xgemm_omp)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_ASSERT(NULL != iprec && NULL != oprec); + libxsmm_xgemm_omp(*iprec, *oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_dgemm_omp)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const double*, const double*, const libxsmm_blasint*, + const double*, const libxsmm_blasint*, + const double*, double*, const libxsmm_blasint*); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_dgemm_omp)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ + libxsmm_dgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_sgemm_omp)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const float*, const float*, const libxsmm_blasint*, + const float*, const libxsmm_blasint*, + const float*, float*, const libxsmm_blasint*); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_sgemm_omp)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + libxsmm_sgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_gemm_batch_omp)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*, + const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*, + const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[], + const libxsmm_blasint*); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_gemm_batch_omp)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + const libxsmm_blasint* batchsize) +{ + LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k); + LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize); + libxsmm_gemm_batch_omp(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc, + *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_begin)(const libxsmm_gemm_precision*, + const int*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const void*, const void*); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_begin)(const libxsmm_gemm_precision* precision, + const int* flags, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const void* alpha, const void* beta) +{ + LIBXSMM_ASSERT(NULL != precision); + libxsmm_mmbatch_begin(*precision, flags, m, n, k, lda, ldb, ldc, alpha, beta); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_end)(void); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_end)(void) +{ + libxsmm_mmbatch_end(); +} + +#endif /*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_ext_xcopy.c b/third_party/libxsmm/src/libxsmm_ext_xcopy.c new file mode 100644 index 0000000000000000000000000000000000000000..b6f2c35a1f918086f9f1c59e4c34eab94ae7eb69 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_ext_xcopy.c @@ -0,0 +1,472 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_xcopy.h" +#include "libxsmm_ext.h" + +#define LIBXSMM_MCOPY_MT(MT, NT, M, N) ((MT) <= (M) && (NT) <= (N) && (64U * 64U) <= (((unsigned int)(M)) * (N))) + + +LIBXSMM_APIEXT void libxsmm_matcopy_omp(void* out, const void* in, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo) +{ + LIBXSMM_INIT + if (0 < typesize && 256 > typesize && m <= ldi && m <= ldo && out != in && + ((NULL != out && 0 < m && 0 < n) || (0 == m && 0 == n))) + { + if (0 < m && 0 < n) { +#if defined(_OPENMP) + unsigned int tm, tn, ts; + if (NULL != in) { /* mcopy */ + tm = LIBXSMM_UPDIV(libxsmm_mcopy_mbytes, typesize); + tn = (unsigned int)(libxsmm_mcopy_nscale * tm); + ts = libxsmm_mcopy_mbytes; + } + else { /* mzero */ + tm = LIBXSMM_UPDIV(libxsmm_mzero_mbytes, typesize); + tn = (unsigned int)(libxsmm_mzero_nscale * tm); + ts = libxsmm_mzero_mbytes; + } + if (0 == tm) tm = m; + if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n); + if (0 != ts && ts < (tm * tn * typesize)) { + tm = LIBXSMM_MAX(ts / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN); + } + if (LIBXSMM_MCOPY_MT(tm, tn, (unsigned int)m, (unsigned int)n)) { /* consider problem-size */ + libxsmm_xcopykernel kernel; + kernel.ptr = NULL; +# if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 2)) + if (0 != (2 & libxsmm_xcopy_jit)) { /* JIT'ted matrix-copy permitted? */ + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + } + } +# endif +# if defined(LIBXSMM_EXT_TASKS) && 0/* implies _OPENMP */ + if (0 == omp_get_active_level()) +# else + if (0 == omp_in_parallel()) +# endif + { /* enable internal parallelization */ + const int nthreads = omp_get_max_threads(); +# if defined(LIBXSMM_EXT_TASKS) + if (0 >= libxsmm_xcopy_taskscale) +# endif + { +# pragma omp parallel num_threads(nthreads) + libxsmm_matcopy_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, omp_get_thread_num(), nthreads); + } +# if defined(LIBXSMM_EXT_TASKS) + else { /* tasks requested */ + const int ntasks = nthreads * libxsmm_xcopy_taskscale; +# pragma omp parallel num_threads(nthreads) + { /* first thread discovering work will launch all tasks */ +# pragma omp single nowait /* anyone is good */ + { int tid; + for (tid = 0; tid < ntasks; ++tid) { +# pragma omp task untied + libxsmm_matcopy_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, tid, ntasks); + } + } + } + } +# endif + } + else { /* assume external parallelization */ +# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */ + const int nthreads = omp_get_num_threads(); + const int ntasks = (0 == libxsmm_xcopy_taskscale + ? (LIBXSMM_XCOPY_TASKSCALE) + : libxsmm_xcopy_taskscale) * nthreads; + int tid; + for (tid = 0; tid < ntasks; ++tid) { +# pragma omp task untied + libxsmm_matcopy_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, tid, ntasks); + } + if (0 == libxsmm_nosync) { /* allow to omit synchronization */ +# pragma omp taskwait + } +# else + libxsmm_matcopy_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, 0/*tid*/, 1/*nthreads*/); +# endif + } + } + else +#endif /*defined(_OPENMP)*/ + if (NULL != in) { /* no MT, or small problem-size */ + LIBXSMM_XCOPY_NONJIT(LIBXSMM_MCOPY_KERNEL, + typesize, out, in, ldi, ldo, 0, m, 0, n); + } + else { /* no MT, or small problem-size */ + /* coverity[ptr_arith] */ + LIBXSMM_XCOPY_NONJIT(LIBXSMM_MZERO_KERNEL, + typesize, out, in, ldi, ldo, 0, m, 0, n); + } + } + } + else { + static int error_once = 0; + if ( 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (NULL == out) { + fprintf(stderr, "LIBXSMM ERROR: the matrix-copy input and/or output is NULL!\n"); + } + else if (out == in) { + fprintf(stderr, "LIBXSMM ERROR: output and input of the matrix-copy must be different!\n"); + } + else if (0 == typesize || 256 <= typesize) { + fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-copy specified!\n"); + } + else if (ldi < m || ldo < m) { + fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the matrix-copy is/are too small!\n"); + } + else if (0 > m || 0 > n) { + fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the matrix-copy is/are negative!\n"); + } + } + } +} + + +LIBXSMM_APIEXT void libxsmm_otrans_omp(void* out, const void* in, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo) +{ + static int error_once = 0; + LIBXSMM_INIT + if (0 < typesize && 256 > typesize && m <= ldi && n <= ldo && + ((NULL != out && NULL != in && 0 < m && 0 < n) || (0 == m && 0 == n))) + { + if (0 < m && 0 < n) { + if (out != in) { +#if defined(_OPENMP) + unsigned int tm = LIBXSMM_UPDIV(libxsmm_tcopy_mbytes, typesize); + unsigned int tn = (unsigned int)(libxsmm_tcopy_nscale * tm); + if (0 == tm) tm = m; + if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n); + if (0 != libxsmm_tcopy_mbytes && libxsmm_tcopy_mbytes < (tm * tn * typesize)) { + tm = LIBXSMM_MAX(libxsmm_tcopy_mbytes / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN); + } + if (tm <= (unsigned int)m && tn <= (unsigned int)n) { /* consider problem-size */ + libxsmm_xcopykernel kernel; + kernel.ptr = NULL; +# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */ + if (0 == omp_get_active_level()) +# else + if (0 == omp_in_parallel()) +# endif + { /* enable internal parallelization */ + const int nthreads = omp_get_max_threads(); +# if defined(LIBXSMM_EXT_TASKS) + if (0 >= libxsmm_xcopy_taskscale) +# endif + { +# pragma omp parallel num_threads(nthreads) + { /* coverity[divide_by_zero] */ + libxsmm_otrans_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, omp_get_thread_num(), nthreads); + } + } +# if defined(LIBXSMM_EXT_TASKS) + else { /* tasks requested */ + const int ntasks = nthreads * libxsmm_xcopy_taskscale; +# pragma omp parallel num_threads(nthreads) + { /* first thread discovering work will launch all tasks */ +# pragma omp single nowait /* anyone is good */ + { int tid; + for (tid = 0; tid < ntasks; ++tid) { +# pragma omp task untied + libxsmm_otrans_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, tid, ntasks); + } + } + } + } +# endif + } + else { /* assume external parallelization */ +# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */ + const int nthreads = omp_get_num_threads(); + const int ntasks = (0 == libxsmm_xcopy_taskscale + ? (LIBXSMM_XCOPY_TASKSCALE) + : libxsmm_xcopy_taskscale) * nthreads; + int tid; + for (tid = 0; tid < ntasks; ++tid) { +# pragma omp task untied + libxsmm_otrans_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, tid, ntasks); + } + if (0 == libxsmm_nosync) { /* allow to omit synchronization */ +# pragma omp taskwait + } +# else /* coverity[divide_by_zero] */ + libxsmm_otrans_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, 0/*tid*/, 1/*nthreads*/); +# endif + } + } + else +#endif /*defined(_OPENMP)*/ + { /* no MT, or small problem-size */ +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + libxsmm_xcopykernel kernel; + kernel.ptr = NULL; + if (0 != (1 & libxsmm_xcopy_jit)) { /* JIT'ted transpose permitted? */ + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + } + if (NULL != kernel.ptr) { /* JIT-kernel available */ + LIBXSMM_TCOPY_CALL(kernel, typesize, in, ldi, out, ldo); + } + } + else +#endif + { + LIBXSMM_XCOPY_NONJIT(LIBXSMM_TCOPY_KERNEL, + typesize, out, in, ldi, ldo, 0, m, 0, n); + } + } + } + else if (ldi == ldo) { + libxsmm_itrans/*TODO: omp*/(out, typesize, m, n, ldi, ldo); + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n"); + } + } + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (NULL == out || NULL == in) { + fprintf(stderr, "LIBXSMM ERROR: the transpose input and/or output is NULL!\n"); + } + else if (out == in) { + fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n"); + } + else if (0 == typesize || 256 <= typesize) { + fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-transpose specified!\n"); + } + else if (ldi < m || ldo < n) { + fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the transpose is/are too small!\n"); + } + else if (0 > m || 0 > n) { + fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the transpose is/are negative!\n"); + } + } + } +} + + +LIBXSMM_APIEXT void libxsmm_itrans_batch_omp(void* inout, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, + libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride[], libxsmm_blasint batchsize) +{ +#if defined(_OPENMP) + if (1 < batchsize) { /* consider problem-size */ + const libxsmm_blasint scratchsize = m * n * typesize; + const libxsmm_blasint size = LIBXSMM_ABS(batchsize); + char buffer[LIBXSMM_ITRANS_BUFFER_MAXSIZE]; + char *const mat0 = (char*)inout; + void* scratch = NULL; + libxsmm_xcopykernel kernel = { NULL }; + if (m != n || ldi != ldo || 127 < typesize) { + if (scratchsize <= LIBXSMM_ITRANS_BUFFER_MAXSIZE) { + scratch = buffer; + } + else { + static int error_once = 0; + LIBXSMM_INIT + if (EXIT_SUCCESS != libxsmm_xmalloc(&scratch, scratchsize, 0/*auto-align*/, + LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE, + 0/*extra*/, 0/*extra_size*/) + && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate buffer for in-place transpose!\n"); + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + if (0 != (1 & libxsmm_xcopy_jit) /* JIT'ted transpose permitted? */ + /* avoid outgrown transpose kernel upfront */ + && (m <= LIBXSMM_CONFIG_MAX_DIM || n <= LIBXSMM_CONFIG_MAX_DIM)) + { + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + } + } +#endif + } +# if defined(LIBXSMM_EXT_TASKS) && 0/* implies _OPENMP */ + if (0 == omp_get_active_level()) +# else + if (0 == omp_in_parallel()) +# endif + { /* enable internal parallelization */ + const int nthreads = omp_get_max_threads(); +# if defined(LIBXSMM_EXT_TASKS) + if (0 >= libxsmm_xcopy_taskscale) +# endif + { + const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, nthreads); +# pragma omp parallel num_threads(nthreads) + { + const libxsmm_blasint begin = omp_get_thread_num() * tasksize; + const libxsmm_blasint span = begin + tasksize; + libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base, + index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size)); + } + } +# if defined(LIBXSMM_EXT_TASKS) + else { /* tasks requested */ + const int ntasks = nthreads * libxsmm_xcopy_taskscale; + const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks); +# pragma omp parallel num_threads(nthreads) + { /* first thread discovering work will launch all tasks */ +# pragma omp single nowait /* anyone is good */ + { int tid; + for (tid = 0; tid < ntasks; ++tid) { + const libxsmm_blasint begin = tid * tasksize; + const libxsmm_blasint span = begin + tasksize; +# pragma omp task untied + libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base, + index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size)); + } + } + } + } +# endif + } + else { /* assume external parallelization */ +# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */ + const int nthreads = omp_get_num_threads(); + const int ntasks = (0 == libxsmm_xcopy_taskscale + ? (LIBXSMM_XCOPY_TASKSCALE) + : libxsmm_xcopy_taskscale) * nthreads; + const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks); + int tid; + for (tid = 0; tid < ntasks; ++tid) { + const libxsmm_blasint begin = tid * tasksize; + const libxsmm_blasint span = begin + tasksize; +# pragma omp task untied + libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base, + index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size)); + } + if (0 == libxsmm_nosync) { /* allow to omit synchronization */ +# pragma omp taskwait + } +# else + libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base, + index_stride, stride, kernel, 0, batchsize); +# endif + } + if (NULL != scratch && LIBXSMM_ITRANS_BUFFER_MAXSIZE < scratchsize) { + libxsmm_xfree(scratch, 0/*no check*/); + } + } + else +#endif /*defined(_OPENMP)*/ + libxsmm_itrans_batch(inout, typesize, m, n, ldi, ldo, + index_base, index_stride, stride, batchsize, + 0/*tid*/, 1/*ntasks*/); +} + + +#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_matcopy_omp)(void* /*out*/, const void* /*in*/, const int* /*typesize*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_matcopy_omp)(void* out, const void* in, const int* typesize, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo) +{ + libxsmm_blasint ldx; + LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m); + ldx = *(NULL != ldi ? ldi : m); + libxsmm_matcopy_omp(out, in, (unsigned int)*typesize, *m, *(NULL != n ? n : m), ldx, NULL != ldo ? *ldo : ldx); +} + + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_otrans_omp)(void* /*out*/, const void* /*in*/, const int* /*typesize*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_otrans_omp)(void* out, const void* in, const int* typesize, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo) +{ + libxsmm_blasint ldx; + LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m); + ldx = *(NULL != ldi ? ldi : m); + libxsmm_otrans_omp(out, in, (unsigned int)*typesize, *m, *(NULL != n ? n : m), ldx, NULL != ldo ? *ldo : ldx); +} + +#endif /*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_fsspmdm.c b/third_party/libxsmm/src/libxsmm_fsspmdm.c new file mode 100644 index 0000000000000000000000000000000000000000..5bcc447c8d655a2240d22b6c8b01c65201f67762 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_fsspmdm.c @@ -0,0 +1,602 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "generator_spgemm_csr_asparse_reg.h" +#include +#include "libxsmm_main.h" + + +/* Double precision AVX-512 lane broadcasts */ +LIBXSMM_APIVAR_DEFINE(const double* internal_fsspmdm_dperm); +/* Single precision AVX-512 lane broadcasts */ +LIBXSMM_APIVAR_DEFINE(const float* internal_fsspmdm_sperm); + + +LIBXSMM_API_INTERN void internal_dfsspmdm_init(void); +LIBXSMM_API_INTERN void internal_dfsspmdm_init(void) +{ + LIBXSMM_ALIGNED(static const unsigned int dperm[], LIBXSMM_ALIGNMENT) = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, + 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, + 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15 + }; + LIBXSMM_ASSERT(NULL == internal_fsspmdm_dperm); + internal_fsspmdm_dperm = (const double*)((const void*)dperm); +} + + +LIBXSMM_API_INTERN void internal_sfsspmdm_init(void); +LIBXSMM_API_INTERN void internal_sfsspmdm_init(void) +{ + LIBXSMM_ALIGNED(static const unsigned int sperm[], LIBXSMM_ALIGNMENT) = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15 + }; + LIBXSMM_ASSERT(NULL == internal_fsspmdm_sperm); + internal_fsspmdm_sperm = (const float*)((const void*)sperm); +} + + +LIBXSMM_API libxsmm_dfsspmdm* libxsmm_dfsspmdm_create( + libxsmm_blasint M, libxsmm_blasint N, libxsmm_blasint K, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + const double alpha, const double beta, libxsmm_blasint c_is_nt, + const double* a_dense) +{ + double one = 1.0; + double* a_csr_values = NULL; + unsigned int* a_csr_rowptr = NULL; + unsigned int* a_csr_colidx = NULL; + double* aa_dense = NULL; + int flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + const libxsmm_gemm_prefetch_type prefetch = LIBXSMM_GEMM_PREFETCH_NONE; + const libxsmm_gemm_descriptor* xgemm_desc; + libxsmm_descriptor_blob xgemm_blob; + libxsmm_dfsspmdm* new_handle = NULL; + libxsmm_dmmfunction k_sparse1 = NULL; + libxsmm_dmmfunction k_sparse2 = NULL; + libxsmm_dmmfunction k_dense = NULL; + int i, j, n, a_nnz, N_sparse1, N_sparse2, N_dense, nkerns; + + /* internal lazy initialization */ + if (NULL == internal_fsspmdm_dperm) internal_dfsspmdm_init(); + + /* some checks... */ + assert(N % 8 == 0); + assert(N >= 8); + assert(LIBXSMM_FEQ(beta, 1.0) || LIBXSMM_FEQ(beta, 0.0)); + assert(K <= lda); + assert(N <= ldc); + assert(N <= ldb); + + /* Get the number of non-zeros */ + a_nnz = 0; + for (i = 0; i < M; i++) { + for (j = 0; j < K; j++) { + if (LIBXSMM_NEQ(a_dense[(i*lda) + j], 0.0)) { + a_nnz++; + } + } + } + + /* Null matrix */ + if ( 0 == a_nnz ) return NULL; + + /* Allocate handle */ + new_handle = (libxsmm_dfsspmdm*)malloc(sizeof(libxsmm_dfsspmdm)); + if ( NULL == new_handle ) return NULL; + + /* Initialize the handle */ + LIBXSMM_MEMZERO127(new_handle); + /* TODO: in case of ILP64, check value ranges */ + new_handle->N = (int)N; + new_handle->M = (int)M; + new_handle->K = (int)K; + new_handle->ldb = (int)ldb; + new_handle->ldc = (int)ldc; + + /* update flags */ + if ( beta == 0.0 && c_is_nt != 0 ) { + flags |= LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT; + } + + /* Allocate CSR structure */ + a_csr_values = (double*)malloc((size_t)a_nnz * sizeof(double)); + a_csr_rowptr = (unsigned int*)malloc(((size_t)M + 1) * sizeof(unsigned int)); + a_csr_colidx = (unsigned int*)malloc((size_t)a_nnz * sizeof(unsigned int)); + + /* Allocate dense storage */ + aa_dense = (double*)libxsmm_aligned_malloc((size_t)M * (size_t)K * sizeof(double), 64); + + if ( NULL == a_csr_values || NULL == a_csr_rowptr || NULL == a_csr_colidx || NULL == aa_dense ) { + free( a_csr_values ); free( a_csr_rowptr ); free( a_csr_colidx ); + free( new_handle ); + libxsmm_free( aa_dense ); + return NULL; + } + + /* Populate CSR structure */ + for (i = 0, n = 0; i < M; i++) { + a_csr_rowptr[i] = n; + for (j = 0; j < K; j++) { + if (LIBXSMM_NEQ(a_dense[(i*lda) + j], 0.0)) { + a_csr_values[n] = alpha*a_dense[(i*lda) + j]; + a_csr_colidx[n] = j; + n++; + } + } + } + a_csr_rowptr[M] = a_nnz; + + /* Attempt to JIT a sparse kernel */ + N_sparse1 = libxsmm_cpuid_vlen32(libxsmm_cpuid()) / 2; + xgemm_desc = libxsmm_dgemm_descriptor_init(&xgemm_blob, M, N_sparse1, K, + 0, ldb, ldc, one, beta, flags, prefetch); + if ( NULL != xgemm_desc ) { + k_sparse1 = libxsmm_create_dcsr_reg(xgemm_desc, a_csr_rowptr, a_csr_colidx, a_csr_values); + } + + /* If that worked try to JIT a second (wider) sparse kernel */ + N_sparse2 = N_sparse1*2; + if ( NULL != k_sparse1 && N_sparse2 <= N ) { + xgemm_desc = libxsmm_dgemm_descriptor_init(&xgemm_blob, M, N_sparse2, K, + 0, ldb, ldc, one, beta, flags, prefetch); + + if ( NULL != xgemm_desc ) { + k_sparse2 = libxsmm_create_dcsr_reg(xgemm_desc, a_csr_rowptr, a_csr_colidx, a_csr_values); + } + } + + /* Free CSR */ + free( a_csr_values ); + free( a_csr_rowptr ); + free( a_csr_colidx ); + + /* Also generate a dense kernel */ + N_dense = 8; + k_dense = libxsmm_dmmdispatch(N_dense, M, K, &ldb, &K, &ldc, &one, &beta, &flags, (const int*)LIBXSMM_GEMM_PREFETCH_NONE); + + if ( NULL != k_dense ) { + /* copy A over */ + for ( i = 0; i < M; ++i ) { + for ( j = 0; j < K; ++j ) { + aa_dense[i*K + j] = alpha*a_dense[i*lda + j]; + } + } + } + + /* Tally up how many kernels we got */ + nkerns = !!k_dense + !!k_sparse1 + !!k_sparse2; + + /* We have at least one kernel */ + if ( nkerns ) { + libxsmm_timer_tickint t; + double *B = NULL, *C = NULL; + double dt_dense = ( NULL != k_dense ) ? 1e5 : 1e6; + double dt_sparse1 = ( NULL != k_sparse1 ) ? 1e5 : 1e6; + double dt_sparse2 = ( NULL != k_sparse2 ) ? 1e5 : 1e6; + void* fp; + + /* If we have two or more kernels then try to benchmark them */ + if ( nkerns >= 2 ) { + B = (double*)libxsmm_aligned_malloc((size_t)K * (size_t)ldb * sizeof(double), 64); + C = (double*)libxsmm_aligned_malloc((size_t)M * (size_t)ldc * sizeof(double), 64); + + if ( NULL != B && NULL != C ) { + for ( i = 0; i < K; i++ ) { + for ( j = 0; j < N; j++ ) { + B[i*ldb + j] = 1; + } + } + for ( i = 0; i < M; i++ ) { + for ( j = 0; j < N; j++ ) { + C[i*ldc + j] = 1; + } + } + } + } + + /* Benchmark dense */ + if ( NULL != k_dense && NULL != B && NULL != C ) { + t = libxsmm_timer_tick(); + for ( i = 0; i < 250; i++ ) { + for ( j = 0; j < N; j += N_dense ) { + k_dense( B + j, aa_dense, C + j ); + } + } + dt_dense = libxsmm_timer_duration( t, libxsmm_timer_tick() ); + } + + /* Benchmark sparse (regular) */ + if ( NULL != k_sparse1 && NULL != B && NULL != C ) { + t = libxsmm_timer_tick(); + for ( i = 0; i < 250; i++ ) { + for ( j = 0; j < N; j += N_sparse1 ) { + k_sparse1( internal_fsspmdm_dperm, B + j, C + j ); + } + } + dt_sparse1 = libxsmm_timer_duration( t, libxsmm_timer_tick() ); + } + + /* Benchmark sparse (wide) */ + if ( NULL != k_sparse2 && NULL != B && NULL != C ) { + t = libxsmm_timer_tick(); + for ( i = 0; i < 250; i++ ) { + for ( j = 0; j < N; j += N_sparse2 ) { + k_sparse2( internal_fsspmdm_dperm, B + j, C + j ); + } + } + dt_sparse2 = libxsmm_timer_duration( t, libxsmm_timer_tick() ); + } + + /* Dense fastest */ + if ( dt_dense <= dt_sparse1 && dt_dense <= dt_sparse2 ) { + new_handle->N_chunksize = N_dense; + new_handle->kernel = k_dense; + new_handle->a_dense = aa_dense; + } else { + libxsmm_free( aa_dense ); + } + + /* Sparse (regular) fastest */ + if ( dt_sparse1 < dt_dense && dt_sparse1 <= dt_sparse2 ) { + new_handle->N_chunksize = N_sparse1; + new_handle->kernel = k_sparse1; + } else if ( NULL != k_sparse1 ) { + LIBXSMM_ASSIGN127( &fp, &k_sparse1 ); + libxsmm_free( fp ); + } + + /* Sparse (wide) fastest */ + if ( dt_sparse2 < dt_dense && dt_sparse2 < dt_sparse1 ) { + new_handle->N_chunksize = N_sparse2; + new_handle->kernel = k_sparse2; + } else if ( NULL != k_sparse2 ) { + LIBXSMM_ASSIGN127( &fp, &k_sparse2 ); + libxsmm_free( fp ); + } + + libxsmm_free( B ); + libxsmm_free( C ); + } + else { + libxsmm_free( aa_dense ); + free( new_handle ); + new_handle = NULL; + } + + return new_handle; +} + + +LIBXSMM_API libxsmm_sfsspmdm* libxsmm_sfsspmdm_create( + libxsmm_blasint M, libxsmm_blasint N, libxsmm_blasint K, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + const float alpha, const float beta, libxsmm_blasint c_is_nt, + const float* a_dense) +{ + float one = 1.0f; + float* a_csr_values = NULL; + unsigned int* a_csr_rowptr = NULL; + unsigned int* a_csr_colidx = NULL; + float* aa_dense = NULL; + int flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + const libxsmm_gemm_prefetch_type prefetch = LIBXSMM_GEMM_PREFETCH_NONE; + const libxsmm_gemm_descriptor* xgemm_desc; + libxsmm_descriptor_blob xgemm_blob; + libxsmm_sfsspmdm* new_handle = NULL; + libxsmm_smmfunction k_sparse1 = NULL; + libxsmm_smmfunction k_sparse2 = NULL; + libxsmm_smmfunction k_dense = NULL; + int i, j, n, a_nnz, N_sparse1, N_sparse2, N_dense, nkerns; + + /* internal lazy initialization */ + if (NULL == internal_fsspmdm_sperm) internal_sfsspmdm_init(); + + /* some checks... */ + assert(N % 16 == 0); + assert(N >= 16); + assert(LIBXSMM_FEQ(beta, 1.0f) || LIBXSMM_FEQ(beta, 0.0f)); + assert(K <= lda); + assert(N <= ldc); + assert(N <= ldb); + + /* Get the number of non-zeros */ + a_nnz = 0; + for (i = 0; i < M; i++) { + for (j = 0; j < K; j++) { + if (LIBXSMM_NEQ(a_dense[(i*lda) + j], 0.0)) { + a_nnz++; + } + } + } + + /* Null matrix */ + if ( 0 == a_nnz ) return 0; + + /* Allocate handle */ + new_handle = (libxsmm_sfsspmdm*)malloc(sizeof(libxsmm_sfsspmdm)); + if ( NULL == new_handle ) return NULL; + + /* Initialize the handle */ + LIBXSMM_MEMZERO127(new_handle); + /* TODO: in case of ILP64, check value ranges */ + new_handle->N = (int)N; + new_handle->M = (int)M; + new_handle->K = (int)K; + new_handle->ldb = (int)ldb; + new_handle->ldc = (int)ldc; + + /* update flags */ + if ( beta == 0.0 && c_is_nt != 0 ) { + flags |= LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT; + } + + /* Allocate CSR structure */ + a_csr_values = (float*)malloc((size_t)a_nnz * sizeof(float)); + a_csr_rowptr = (unsigned int*)malloc(((size_t)M + 1) * sizeof(unsigned int)); + a_csr_colidx = (unsigned int*)malloc((size_t)a_nnz * sizeof(unsigned int)); + + /* Allocate dense storage */ + aa_dense = (float*)libxsmm_aligned_malloc((size_t)M * (size_t)K * sizeof(float), 64); + + if ( NULL == a_csr_values || NULL == a_csr_rowptr || NULL == a_csr_colidx || NULL == aa_dense ) { + free( a_csr_values ); free( a_csr_rowptr ); free( a_csr_colidx ); + free( new_handle ); + libxsmm_free( aa_dense ); + return NULL; + } + + /* Populate CSR structure */ + for (i = 0, n = 0; i < M; i++) { + a_csr_rowptr[i] = n; + for (j = 0; j < K; j++) { + if (LIBXSMM_NEQ(a_dense[(i*lda) + j], 0.0f)) { + a_csr_values[n] = alpha*a_dense[(i*lda) + j]; + a_csr_colidx[n] = j; + n++; + } + } + } + a_csr_rowptr[M] = a_nnz; + + /* Attempt to JIT a sparse kernel */ + N_sparse1 = libxsmm_cpuid_vlen32(libxsmm_cpuid()); + xgemm_desc = libxsmm_sgemm_descriptor_init(&xgemm_blob, M, N_sparse1, K, + 0, ldb, ldc, one, beta, flags, prefetch); + if ( NULL != xgemm_desc ) { + k_sparse1 = libxsmm_create_scsr_reg(xgemm_desc, a_csr_rowptr, a_csr_colidx, a_csr_values); + } + + /* If that worked try to JIT a second (wider) sparse kernel */ + N_sparse2 = N_sparse1*2; + if ( NULL != k_sparse1 && N_sparse2 <= N ) { + xgemm_desc = libxsmm_sgemm_descriptor_init(&xgemm_blob, M, N_sparse2, K, + 0, ldb, ldc, one, beta, flags, prefetch); + + if ( NULL != xgemm_desc ) { + k_sparse2 = libxsmm_create_scsr_reg(xgemm_desc, a_csr_rowptr, a_csr_colidx, a_csr_values); + } + } + + /* Free CSR */ + free( a_csr_values ); + free( a_csr_rowptr ); + free( a_csr_colidx ); + + /* Also generate a dense kernel */ + N_dense = 16; + k_dense = libxsmm_smmdispatch(N_dense, M, K, &ldb, &K, &ldc, &one, &beta, &flags, (const int*)LIBXSMM_GEMM_PREFETCH_NONE); + + if ( NULL != k_dense ) { + /* copy A over */ + for ( i = 0; i < M; ++i ) { + for ( j = 0; j < K; ++j ) { + aa_dense[i*K + j] = alpha*a_dense[i*lda + j]; + } + } + } + + /* Tally up how many kernels we got */ + nkerns = !!k_dense + !!k_sparse1 + !!k_sparse2; + + /* We have at least one kernel */ + if ( nkerns ) { + libxsmm_timer_tickint t; + float *B = NULL, *C = NULL; + double dt_dense = ( NULL != k_dense ) ? 1e5 : 1e6; + double dt_sparse1 = ( NULL != k_sparse1 ) ? 1e5 : 1e6; + double dt_sparse2 = ( NULL != k_sparse2 ) ? 1e5 : 1e6; + void* fp; + + /* If we have two or more kernels then try to benchmark them */ + if ( nkerns >= 2 ) { + B = (float*)libxsmm_aligned_malloc((size_t)K * (size_t)ldb * sizeof(float), 64); + C = (float*)libxsmm_aligned_malloc((size_t)M * (size_t)ldc * sizeof(float), 64); + + if ( NULL != B && NULL != C ) { + for ( i = 0; i < K; i++ ) { + for ( j = 0; j < N; j++ ) { + B[i*ldb + j] = 1; + } + } + for ( i = 0; i < M; i++ ) { + for ( j = 0; j < N; j++ ) { + C[i*ldc + j] = 1; + } + } + } + } + + /* Benchmark dense */ + if ( NULL != k_dense && NULL != B && NULL != C ) { + t = libxsmm_timer_tick(); + for ( i = 0; i < 250; i++ ) { + for ( j = 0; j < N; j += N_dense ) { + k_dense( B + j, aa_dense, C + j ); + } + } + dt_dense = libxsmm_timer_duration( t, libxsmm_timer_tick() ); + } + + /* Benchmark sparse (regular) */ + if ( NULL != k_sparse1 && NULL != B && NULL != C ) { + t = libxsmm_timer_tick(); + for ( i = 0; i < 250; i++ ) { + for ( j = 0; j < N; j += N_sparse1 ) { + k_sparse1( internal_fsspmdm_sperm, B + j, C + j ); + } + } + dt_sparse1 = libxsmm_timer_duration( t, libxsmm_timer_tick() ); + } + + /* Benchmark sparse (wide) */ + if ( NULL != k_sparse2 && NULL != B && NULL != C ) { + t = libxsmm_timer_tick(); + for ( i = 0; i < 250; i++ ) { + for ( j = 0; j < N; j += N_sparse2 ) { + k_sparse2( internal_fsspmdm_sperm, B + j, C + j ); + } + } + dt_sparse2 = libxsmm_timer_duration( t, libxsmm_timer_tick() ); + } + + /* Dense fastest */ + if ( dt_dense <= dt_sparse1 && dt_dense <= dt_sparse2 ) { + new_handle->N_chunksize = N_dense; + new_handle->kernel = k_dense; + new_handle->a_dense = aa_dense; + } else { + libxsmm_free( aa_dense ); + } + + /* Sparse (regular) fastest */ + if ( dt_sparse1 < dt_dense && dt_sparse1 <= dt_sparse2 ) { + new_handle->N_chunksize = N_sparse1; + new_handle->kernel = k_sparse1; + } else if ( NULL != k_sparse1 ) { + LIBXSMM_ASSIGN127( &fp, &k_sparse1 ); + libxsmm_free( fp ); + } + + /* Sparse (wide) fastest */ + if ( dt_sparse2 < dt_dense && dt_sparse2 < dt_sparse1 ) { + new_handle->N_chunksize = N_sparse2; + new_handle->kernel = k_sparse2; + } else if ( NULL != k_sparse2 ) { + LIBXSMM_ASSIGN127( &fp, &k_sparse2 ); + libxsmm_free( fp ); + } + + libxsmm_free( B ); + libxsmm_free( C ); + } + else { + libxsmm_free( aa_dense ); + free( new_handle ); + new_handle = NULL; + } + + return new_handle; +} + + +LIBXSMM_API void libxsmm_dfsspmdm_execute( const libxsmm_dfsspmdm* handle, const double* B, double* C ) +{ + int i; + assert( handle != NULL ); + + if ( handle->a_dense == NULL ) { + for ( i = 0; i < handle->N; i+=handle->N_chunksize ) { + handle->kernel( internal_fsspmdm_dperm, B+i, C+i ); + } + } else { + for ( i = 0; i < handle->N; i+=handle->N_chunksize ) { + handle->kernel( B+i, handle->a_dense, C+i ); + } + } +} + + +LIBXSMM_API void libxsmm_sfsspmdm_execute( const libxsmm_sfsspmdm* handle, const float* B, float* C ) +{ + int i; + assert( handle != NULL ); + + if ( handle->a_dense == NULL ) { + for ( i = 0; i < handle->N; i+=handle->N_chunksize ) { + handle->kernel( internal_fsspmdm_sperm, B+i, C+i ); + } + } else { + for ( i = 0; i < handle->N; i+=handle->N_chunksize ) { + handle->kernel( B+i, handle->a_dense, C+i ); + } + } +} + + +LIBXSMM_API void libxsmm_dfsspmdm_destroy( libxsmm_dfsspmdm* handle ) +{ + assert( handle != NULL ); + + if ( handle->a_dense != NULL ) { + libxsmm_free(handle->a_dense); + } + else { + /* deallocate code known to be not registered; no index attached + do not use libxsmm_release_kernel here! We also need to work + around pointer-to-function to pointer-to-object conversion */ + void* fp; + LIBXSMM_ASSIGN127(&fp, &handle->kernel); + libxsmm_free(fp); + } + + free(handle); +} + + +LIBXSMM_API void libxsmm_sfsspmdm_destroy( libxsmm_sfsspmdm* handle ) +{ + assert( handle != NULL ); + + if ( handle->a_dense != NULL ) { + libxsmm_free(handle->a_dense); + } + else { + /* deallocate code known to be not registered; no index attached + do not use libxsmm_release_kernel here! We also need to work + around pointer-to-function to pointer-to-object conversion */ + void* fp; + LIBXSMM_ASSIGN127(&fp, &handle->kernel); + libxsmm_free(fp); + } + + free(handle); +} + diff --git a/third_party/libxsmm/src/libxsmm_gemm.c b/third_party/libxsmm/src/libxsmm_gemm.c new file mode 100644 index 0000000000000000000000000000000000000000..1c9972346ea2c84d0ff958560db663f638382b77 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_gemm.c @@ -0,0 +1,2156 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_gemm.h" +#include "libxsmm_xcopy.h" +#include "libxsmm_hash.h" +#include + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if !defined(LIBXSMM_NO_LIBM) +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(LIBXSMM_GEMM_XCOPY_JIT) && defined(LIBXSMM_XCOPY_JIT) && (0 != LIBXSMM_XCOPY_JIT) +# define LIBXSMM_GEMM_XCOPY_JIT +#endif +#if !defined(LIBXSMM_GEMM_KPARALLEL) && 0 +# define LIBXSMM_GEMM_KPARALLEL +#endif +#if !defined(LIBXSMM_GEMM_BATCHSIZE) +# define LIBXSMM_GEMM_BATCHSIZE 1024 +#endif +#if !defined(LIBXSMM_GEMM_TASKGRAIN) +# define LIBXSMM_GEMM_TASKGRAIN 128 +#endif +#if !defined(LIBXSMM_GEMM_BATCHREDUCE) && !defined(_WIN32) && !defined(__CYGWIN__) /* not supported */ +# define LIBXSMM_GEMM_BATCHREDUCE +#endif +#if !defined(LIBXSMM_GEMM_BATCHSCALE) && (defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP)) +#define LIBXSMM_GEMM_BATCHSCALE ((unsigned int)LIBXSMM_ROUND(sizeof(libxsmm_mmbatch_item) * (LIBXSMM_GEMM_MMBATCH_SCALE))) +#endif +#if defined(LIBXSMM_BUILD) +# define LIBXSMM_GEMM_WEAK LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK +#else +# define LIBXSMM_GEMM_WEAK LIBXSMM_API +#endif + +#if (0 != LIBXSMM_SYNC) /** Locks for the batch interface (duplicated C indexes). */ +# define LIBXSMM_GEMM_LOCKIDX(IDX, NPOT) LIBXSMM_MOD2(LIBXSMM_CRC32U(LIBXSMM_BLASINT_NBITS)(2507/*seed*/, &(IDX)), NPOT) +# define LIBXSMM_GEMM_LOCKPTR(PTR, NPOT) LIBXSMM_MOD2(LIBXSMM_CRC32U(LIBXSMM_BITS)(1975/*seed*/, &(PTR)), NPOT) +# if !defined(LIBXSMM_GEMM_MAXNLOCKS) +# define LIBXSMM_GEMM_MAXNLOCKS 1024 +# endif +# if !defined(LIBXSMM_GEMM_LOCKFWD) +# define LIBXSMM_GEMM_LOCKFWD +# endif +# if LIBXSMM_LOCK_TYPE_ISPOD(LIBXSMM_GEMM_LOCK) +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_gemm_locktype { + char pad[LIBXSMM_CACHELINE]; + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) state; +} internal_gemm_locktype; +# else +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_gemm_locktype { + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) state; +} internal_gemm_locktype; +# endif +LIBXSMM_APIVAR_DEFINE(internal_gemm_locktype internal_gemm_lock[LIBXSMM_GEMM_MAXNLOCKS]); +LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_nlocks); /* populated number of locks */ +#endif + +/* definition of corresponding variables */ +LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch_function); +LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch_function); +LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemm_function libxsmm_original_dgemm_function); +LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemm_function libxsmm_original_sgemm_function); +LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemv_function libxsmm_original_dgemv_function); +LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemv_function libxsmm_original_sgemv_function); +/* definition of corresponding variables */ +LIBXSMM_APIVAR_PUBLIC_DEF(libxsmm_gemm_descriptor libxsmm_mmbatch_desc); +LIBXSMM_APIVAR_PUBLIC_DEF(void* libxsmm_mmbatch_array); +LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) libxsmm_mmbatch_lock); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_mmbatch_size); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_gemm_npargroups); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_gemm_taskgrain); +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_gemm_tasks); +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_gemm_wrap); + +LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch_default); +/** Determines the prefetch strategy, which is used in case of LIBXSMM_PREFETCH_AUTO. */ +LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch); + +/** Prefetch strategy for tiled GEMM. */ +LIBXSMM_APIVAR_DEFINE(libxsmm_gemm_prefetch_type internal_gemm_tiled_prefetch); +/** Vector width used for GEMM. */ +LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_vwidth); +/** Limit the M-extent of the tile. */ +LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_mlimit); +/** Table of M-extents per type-size (tile shape). */ +LIBXSMM_APIVAR_DEFINE(float internal_gemm_nstretch); +/** Table of M-extents per type-size (tile shape). */ +LIBXSMM_APIVAR_DEFINE(float internal_gemm_kstretch); +/** Determines if batch-reduce is enabled */ +LIBXSMM_APIVAR_DEFINE(int internal_gemm_batchreduce); + + +#if defined(LIBXSMM_BUILD) +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemm_batch)( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ +#if (0 != LIBXSMM_BLAS) +# if defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP) + if (0 > libxsmm_gemm_wrap) { + LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + else +# endif + { + const libxsmm_blasint ptrsize = sizeof(void*); + libxsmm_blasint i, j = 0; + LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != group_count && NULL != group_size); + LIBXSMM_ASSERT(NULL != m_array && NULL != n_array && NULL != k_array && NULL != lda_array && NULL != ldb_array && NULL != ldc_array); + LIBXSMM_ASSERT(NULL != a_array && NULL != b_array && NULL != c_array && NULL != alpha_array && NULL != beta_array); + for (i = 0; i < *group_count; ++i) { + const libxsmm_blasint size = group_size[i]; + libxsmm_dmmbatch_blas(transa_array + i, transb_array + i, m_array[i], n_array[i], k_array[i], alpha_array + i, + a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, + c_array + j, ldc_array + i, 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size); + j += size; + } + } +#else + libxsmm_blas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemm_batch)( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ +#if (0 != LIBXSMM_BLAS) +# if defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP) + if (0 > libxsmm_gemm_wrap) { + LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); + } + else +# endif + { + const libxsmm_blasint ptrsize = sizeof(void*); + libxsmm_blasint i, j = 0; + LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != group_count && NULL != group_size); + LIBXSMM_ASSERT(NULL != m_array && NULL != n_array && NULL != k_array && NULL != lda_array && NULL != ldb_array && NULL != ldc_array); + LIBXSMM_ASSERT(NULL != a_array && NULL != b_array && NULL != c_array && NULL != alpha_array && NULL != beta_array); + for (i = 0; i < *group_count; ++i) { + const libxsmm_blasint size = group_size[i]; + libxsmm_smmbatch_blas(transa_array + i, transb_array + i, m_array[i], n_array[i], k_array[i], alpha_array + i, + a_array + i, lda_array + i, b_array + i, ldb_array + i, beta_array + i, + c_array + i, ldc_array + i, 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size); + j += size; + } + } +#else + libxsmm_blas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_FSYMBOL(dgemm)((LIBXSMM_BLAS_CONST char*)transa, (LIBXSMM_BLAS_CONST char*)transb, + (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, (LIBXSMM_BLAS_CONST libxsmm_blasint*)k, + (LIBXSMM_BLAS_CONST double*)alpha, (LIBXSMM_BLAS_CONST double*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda, + (LIBXSMM_BLAS_CONST double*)b, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldb, + (LIBXSMM_BLAS_CONST double*) beta, c, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldc); +#else + libxsmm_blas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_FSYMBOL(sgemm)((LIBXSMM_BLAS_CONST char*)transa, (LIBXSMM_BLAS_CONST char*)transb, + (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, (LIBXSMM_BLAS_CONST libxsmm_blasint*)k, + (LIBXSMM_BLAS_CONST float*)alpha, (LIBXSMM_BLAS_CONST float*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda, + (LIBXSMM_BLAS_CONST float*)b, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldb, + (LIBXSMM_BLAS_CONST float*) beta, c, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldc); +#else + libxsmm_blas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx, + const double* beta, double* y, const libxsmm_blasint* incy) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_FSYMBOL(dgemv)((LIBXSMM_BLAS_CONST char*)trans, (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, + (LIBXSMM_BLAS_CONST double*)alpha, (LIBXSMM_BLAS_CONST double*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda, + (LIBXSMM_BLAS_CONST double*)x, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incx, + (LIBXSMM_BLAS_CONST double*) beta, y, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incy); +#else + libxsmm_blas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n, + const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx, + const float* beta, float* y, const libxsmm_blasint* incy) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_FSYMBOL(sgemv)((LIBXSMM_BLAS_CONST char*)trans, (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, + (LIBXSMM_BLAS_CONST float*)alpha, (LIBXSMM_BLAS_CONST float*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda, + (LIBXSMM_BLAS_CONST float*)x, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incx, + (LIBXSMM_BLAS_CONST float*) beta, y, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incy); +#else + libxsmm_blas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void __real_dgemm_batch( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + LIBXSMM_FSYMBOL(__real_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void __real_sgemm_batch( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + LIBXSMM_FSYMBOL(__real_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array, + alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, + group_count, group_size); +} +#endif /*defined(LIBXSMM_BUILD)*/ + + +LIBXSMM_GEMM_WEAK libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void) +{ +#if (0 != LIBXSMM_BLAS) && defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP) + LIBXSMM_BLAS_WRAPPER(1, double, gemm_batch, libxsmm_original_dgemm_batch_function, NULL/*unknown*/); + /*LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_batch_function);*/ +#else + LIBXSMM_BLAS_WRAPPER(0, double, gemm_batch, libxsmm_original_dgemm_batch_function, NULL/*unknown*/); +#endif + return libxsmm_original_dgemm_batch_function; +} + + +LIBXSMM_GEMM_WEAK libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void) +{ +#if (0 != LIBXSMM_BLAS) && defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP) + LIBXSMM_BLAS_WRAPPER(1, float, gemm_batch, libxsmm_original_sgemm_batch_function, NULL/*unknown*/); + /*LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_batch_function);*/ +#else + LIBXSMM_BLAS_WRAPPER(0, float, gemm_batch, libxsmm_original_sgemm_batch_function, NULL/*unknown*/); +#endif + return libxsmm_original_sgemm_batch_function; +} + + +LIBXSMM_GEMM_WEAK libxsmm_dgemm_function libxsmm_original_dgemm(void) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, double, gemm, libxsmm_original_dgemm_function, NULL/*unknown*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_function); +#else + LIBXSMM_BLAS_WRAPPER(0, double, gemm, libxsmm_original_dgemm_function, NULL/*unknown*/); +#endif + return libxsmm_original_dgemm_function; +} + + +LIBXSMM_GEMM_WEAK libxsmm_sgemm_function libxsmm_original_sgemm(void) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, float, gemm, libxsmm_original_sgemm_function, NULL/*unknown*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_function); +#else + LIBXSMM_BLAS_WRAPPER(0, float, gemm, libxsmm_original_sgemm_function, NULL/*unknown*/); +#endif + return libxsmm_original_sgemm_function; +} + + +LIBXSMM_GEMM_WEAK libxsmm_dgemv_function libxsmm_original_dgemv(void) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, double, gemv, libxsmm_original_dgemv_function, NULL/*unknown*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_dgemv_function); +#else + LIBXSMM_BLAS_WRAPPER(0, double, gemv, libxsmm_original_dgemv_function, NULL/*unknown*/); +#endif + return libxsmm_original_dgemv_function; +} + + +LIBXSMM_GEMM_WEAK libxsmm_sgemv_function libxsmm_original_sgemv(void) +{ +#if (0 != LIBXSMM_BLAS) + LIBXSMM_BLAS_WRAPPER(1, float, gemv, libxsmm_original_sgemv_function, NULL/*unknown*/); + LIBXSMM_ASSERT(NULL != libxsmm_original_sgemv_function); +#else + LIBXSMM_BLAS_WRAPPER(0, float, gemv, libxsmm_original_sgemv_function, NULL/*unknown*/); +#endif + return libxsmm_original_sgemv_function; +} + + +LIBXSMM_API libxsmm_sink_function libxsmm_blas_error(const char* symbol) +{ + static int error_once = 0; + LIBXSMM_BLAS_ERROR(symbol, &error_once); + return libxsmm_sink; +} + + +LIBXSMM_API_INTERN void libxsmm_gemm_init(int archid) +{ + const char* env_w = getenv("LIBXSMM_GEMM_WRAP"); + LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_GEMM_LOCK) attr; + LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_GEMM_LOCK, &attr); +#if defined(LIBXSMM_WRAP) /* determines if wrap is considered */ + { /* intercepted GEMMs (1: sequential and non-tiled, 2: parallelized and tiled) */ +# if defined(__STATIC) /* with static library the user controls interceptor already */ + libxsmm_gemm_wrap = ((NULL == env_w || 0 == *env_w) /* LIBXSMM_WRAP=0: no promotion */ + ? (0 < (LIBXSMM_WRAP) ? (LIBXSMM_WRAP + 2) : (LIBXSMM_WRAP - 2)) : atoi(env_w)); +# else + libxsmm_gemm_wrap = ((NULL == env_w || 0 == *env_w) ? (LIBXSMM_WRAP) : atoi(env_w)); +# endif + } +#endif + { /* setup prefetch strategy for tiled GEMMs */ + const char *const env_p = getenv("LIBXSMM_TGEMM_PREFETCH"); + const libxsmm_gemm_prefetch_type tiled_prefetch_default = LIBXSMM_GEMM_PREFETCH_AL2_AHEAD; + const int uid = ((NULL == env_p || 0 == *env_p) ? LIBXSMM_PREFETCH_AUTO/*default*/ : atoi(env_p)); + internal_gemm_tiled_prefetch = (0 <= uid ? libxsmm_gemm_uid2prefetch(uid) : tiled_prefetch_default); + } +#if (0 != LIBXSMM_SYNC) + { /* initialize locks for the batch interface */ + const char *const env_locks = getenv("LIBXSMM_GEMM_NLOCKS"); + const int nlocks = ((NULL == env_locks || 0 == *env_locks) ? -1/*default*/ : atoi(env_locks)); + unsigned int i; + internal_gemm_nlocks = LIBXSMM_UP2POT(0 > nlocks ? (LIBXSMM_GEMM_MAXNLOCKS) : LIBXSMM_MIN(nlocks, LIBXSMM_GEMM_MAXNLOCKS)); + for (i = 0; i < internal_gemm_nlocks; ++i) LIBXSMM_LOCK_INIT(LIBXSMM_GEMM_LOCK, &internal_gemm_lock[i].state, &attr); + } +#endif +#if defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP) + { /* determines if batch-reduce kernel or batch-wrap is considered */ + const char *const env_r = getenv("LIBXSMM_GEMM_BATCHREDUCE"); + internal_gemm_batchreduce = (NULL == env_r || 0 == *env_r) ? 0 : atoi(env_r); + if ((NULL == env_w || 0 == *env_w) && ((LIBXSMM_GEMM_MMBATCH_VERBOSITY <= libxsmm_verbosity && INT_MAX != libxsmm_verbosity) || 0 > libxsmm_verbosity)) { + libxsmm_mmbatch_desc.flags = LIBXSMM_MMBATCH_FLAG_STATISTIC; /* enable auto-batch statistic */ + internal_gemm_batchreduce = 0; + } + if (0 != internal_gemm_batchreduce || 0 != libxsmm_gemm_wrap) { + const char *const env_b = getenv("LIBXSMM_GEMM_BATCHSIZE"); + const int env_bi = (NULL == env_b || 0 == *env_b) ? -1/*auto*/ : atoi(env_b); + const unsigned int env_bu = (unsigned int)(0 >= env_bi ? (LIBXSMM_GEMM_BATCHSIZE) : env_bi); + const unsigned int batchscale = LIBXSMM_ABS(internal_gemm_batchreduce) * 2048/*arbitrary*/ * 2/*A and B-matrices*/ * sizeof(void*); + const unsigned int minsize = LIBXSMM_UPDIV(batchscale * env_bu, LIBXSMM_GEMM_BATCHSCALE); + const unsigned int batchsize = LIBXSMM_MAX(env_bu, minsize); + const void *const extra = NULL; + LIBXSMM_ASSERT(1 < (LIBXSMM_GEMM_MMBATCH_SCALE) && NULL == libxsmm_mmbatch_array); + if (EXIT_SUCCESS == libxsmm_xmalloc(&libxsmm_mmbatch_array, (size_t)batchsize * (LIBXSMM_GEMM_BATCHSCALE), 0/*auto-alignment*/, + LIBXSMM_MALLOC_FLAG_PRIVATE /*| LIBXSMM_MALLOC_FLAG_SCRATCH*/, &extra, sizeof(extra))) + { + LIBXSMM_LOCK_INIT(LIBXSMM_GEMM_LOCK, &libxsmm_mmbatch_lock, &attr); + LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array); + libxsmm_mmbatch_size = batchsize; + } + } + } +#else + LIBXSMM_UNUSED(env_w); +#endif + { /* determines grain-size of tasks (when available) */ + const char *const env_s = getenv("LIBXSMM_GEMM_NPARGROUPS"); + libxsmm_gemm_npargroups = ((NULL == env_s || 0 == *env_s || 0 >= atoi(env_s)) + ? (LIBXSMM_GEMM_NPARGROUPS) : atoi(env_s)); + } + if (LIBXSMM_X86_AVX512_CORE <= archid) { + internal_gemm_vwidth = 64; + internal_gemm_mlimit = 48; + internal_gemm_nstretch = 3.0f; + internal_gemm_kstretch = 2.0f; + } + else if (LIBXSMM_X86_AVX512_MIC <= archid) { + internal_gemm_vwidth = 64; + internal_gemm_mlimit = 64; + internal_gemm_nstretch = 1.0f; + internal_gemm_kstretch = 1.0f; + } + else if (LIBXSMM_X86_AVX2 <= archid) { + internal_gemm_vwidth = 32; + internal_gemm_mlimit = 48; + internal_gemm_nstretch = 3.0f; + internal_gemm_kstretch = 2.0f; + } + else if (LIBXSMM_X86_AVX <= archid) { + internal_gemm_vwidth = 32; + internal_gemm_mlimit = 48; + internal_gemm_nstretch = 5.0f; + internal_gemm_kstretch = 1.0f; + } + else { + internal_gemm_vwidth = 16; + internal_gemm_mlimit = 48; + internal_gemm_nstretch = 7.0f; + internal_gemm_kstretch = 5.0f; + } + { /* setup tile sizes according to environment (LIBXSMM_TGEMM_M, LIBXSMM_TGEMM_N, LIBXSMM_TGEMM_K) */ + const char *const env_m = getenv("LIBXSMM_TGEMM_M"), *const env_n = getenv("LIBXSMM_TGEMM_N"), *const env_k = getenv("LIBXSMM_TGEMM_K"); + const int m = ((NULL == env_m || 0 == *env_m) ? 0 : atoi(env_m)); + const int n = ((NULL == env_n || 0 == *env_n) ? 0 : atoi(env_n)); + const int k = ((NULL == env_k || 0 == *env_k) ? 0 : atoi(env_k)); + if (0 < m) { + if (0 < n) internal_gemm_nstretch = ((float)n) / m; + if (0 < k) internal_gemm_kstretch = ((float)k) / m; + } + } + { /* setup tile sizes according to environment (LIBXSMM_TGEMM_NS, LIBXSMM_TGEMM_KS) */ + const char *const env_ns = getenv("LIBXSMM_TGEMM_NS"), *const env_ks = getenv("LIBXSMM_TGEMM_KS"); + const double ns = ((NULL == env_ns || 0 == *env_ns) ? 0 : atof(env_ns)); + const double ks = ((NULL == env_ks || 0 == *env_ks) ? 0 : atof(env_ks)); + if (0 < ns) internal_gemm_nstretch = (float)LIBXSMM_MIN(24, ns); + if (0 < ks) internal_gemm_kstretch = (float)LIBXSMM_MIN(24, ks); + } + { /* determines if OpenMP tasks are used (when available) */ + const char *const env_t = getenv("LIBXSMM_GEMM_TASKS"); + const int gemm_tasks = ((NULL == env_t || 0 == *env_t) ? 0/*disabled*/ : atoi(env_t)); + libxsmm_gemm_tasks = (0 <= gemm_tasks ? LIBXSMM_ABS(gemm_tasks) : 1/*enabled*/); + } + { /* determines grain-size of tasks (when available) */ + const char *const env_g = getenv("LIBXSMM_GEMM_TASKGRAIN"); + const int gemm_taskgrain = ((NULL == env_g || 0 == *env_g || 0 >= atoi(env_g)) + ? (LIBXSMM_GEMM_TASKGRAIN) : atoi(env_g)); + /* adjust grain-size or scale beyond the number of threads */ + libxsmm_gemm_taskgrain = LIBXSMM_MAX(0 < libxsmm_gemm_tasks ? (gemm_taskgrain / libxsmm_gemm_tasks) : gemm_taskgrain, 1); + } + LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_GEMM_LOCK, &attr); + /* determine BLAS function-pointers */ + libxsmm_original_dgemm_batch(); + libxsmm_original_sgemm_batch(); + libxsmm_original_dgemm(); + libxsmm_original_sgemm(); + libxsmm_original_dgemv(); + libxsmm_original_sgemv(); +} + + +LIBXSMM_API_INTERN void libxsmm_gemm_finalize(void) +{ +#if (0 != LIBXSMM_SYNC) + unsigned int i; for (i = 0; i < internal_gemm_nlocks; ++i) LIBXSMM_LOCK_DESTROY(LIBXSMM_GEMM_LOCK, &internal_gemm_lock[i].state); +#endif +#if defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP) + if (NULL != libxsmm_mmbatch_array) { + void *extra = NULL, *const mmbatch_array = libxsmm_mmbatch_array; + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra) && NULL != extra) { + const libxsmm_mmbatch_flush_function flush = *(libxsmm_mmbatch_flush_function*)extra; + if (NULL != flush) flush(); + } +#if !defined(NDEBUG) + libxsmm_mmbatch_array = NULL; +#endif + libxsmm_xfree(mmbatch_array, 0/*no check*/); + LIBXSMM_LOCK_DESTROY(LIBXSMM_GEMM_LOCK, &libxsmm_mmbatch_lock); + } +#endif +} + + +LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_xprefetch(const int* prefetch) +{ + LIBXSMM_INIT /* load configuration */ + return libxsmm_get_gemm_prefetch(NULL == prefetch ? ((int)libxsmm_gemm_auto_prefetch) : *prefetch); +} + + +LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_prefetch(int prefetch) +{ + libxsmm_gemm_prefetch_type result; +#if !defined(_WIN32) && !defined(__CYGWIN__) && !defined(__MINGW32__) + if (0 > prefetch) { + LIBXSMM_INIT /* load configuration */ + result = libxsmm_gemm_auto_prefetch_default; + } + else { + result = (libxsmm_gemm_prefetch_type)prefetch; + } +#else /* TODO: full support for Windows calling convention */ + result = LIBXSMM_GEMM_PREFETCH_NONE; + LIBXSMM_UNUSED(prefetch); +#endif + return result; +} + + +LIBXSMM_API_INTERN int libxsmm_gemm_prefetch2uid(libxsmm_gemm_prefetch_type prefetch) +{ + switch (prefetch) { + case LIBXSMM_GEMM_PREFETCH_SIGONLY: return 2; + case LIBXSMM_GEMM_PREFETCH_BL2_VIA_C: return 3; + case LIBXSMM_GEMM_PREFETCH_AL2_AHEAD: return 4; + case LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD: return 5; + case LIBXSMM_GEMM_PREFETCH_AL2: return 6; + case LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C: return 7; + case LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB: return 8; + default: { + LIBXSMM_ASSERT(LIBXSMM_GEMM_PREFETCH_NONE == prefetch); + return 0; + } + } +} + + +LIBXSMM_API_INTERN libxsmm_gemm_prefetch_type libxsmm_gemm_uid2prefetch(int uid) +{ + switch (uid) { + case 1: return LIBXSMM_GEMM_PREFETCH_NONE; /* nopf */ + case 2: return LIBXSMM_GEMM_PREFETCH_SIGONLY; /* pfsigonly */ + case 3: return LIBXSMM_GEMM_PREFETCH_BL2_VIA_C; /* BL2viaC */ + case 4: return LIBXSMM_GEMM_PREFETCH_AL2_AHEAD; /* curAL2 */ + case 5: return LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD; /* curAL2_BL2viaC */ + case 6: return LIBXSMM_GEMM_PREFETCH_AL2; /* AL2 */ + case 7: return LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C; /* AL2_BL2viaC */ + case 8: return LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB; + default: { + if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM WARNING: invalid prefetch strategy requested!\n"); + } + } + return LIBXSMM_GEMM_PREFETCH_NONE; + } + } +} + + +LIBXSMM_API void libxsmm_gemm_print(void* ostream, + libxsmm_gemm_precision precision, const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc) +{ + libxsmm_gemm_print2(ostream, precision, precision, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_API void libxsmm_gemm_print2(void* ostream, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, + const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc) +{ + const libxsmm_blasint nn = *(n ? n : m), kk = *(k ? k : m); + const char ctransa = (char)(NULL != transa ? (*transa) : (0 == (LIBXSMM_FLAGS & LIBXSMM_GEMM_FLAG_TRANS_A) ? 'n' : 't')); + const char ctransb = (char)(NULL != transb ? (*transb) : (0 == (LIBXSMM_FLAGS & LIBXSMM_GEMM_FLAG_TRANS_B) ? 'n' : 't')); + const libxsmm_blasint ilda = (NULL != lda ? *lda : (('n' == ctransa || 'N' == ctransa) ? *m : kk)); + const libxsmm_blasint ildb = (NULL != ldb ? *ldb : (('n' == ctransb || 'N' == ctransb) ? kk : nn)); + const libxsmm_blasint ildc = *(NULL != ldc ? ldc : m); + libxsmm_mhd_elemtype mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_UNKNOWN; + char string_a[128], string_b[128], typeprefix = 0; + + switch (iprec | oprec) { + case LIBXSMM_GEMM_PRECISION_F64: { + LIBXSMM_ASSERT(iprec == oprec); + LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "%g", NULL != alpha ? *((const double*)alpha) : LIBXSMM_ALPHA); + LIBXSMM_SNPRINTF(string_b, sizeof(string_b), "%g", NULL != beta ? *((const double*)beta) : LIBXSMM_BETA); + mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_F64; + typeprefix = 'd'; + } break; + case LIBXSMM_GEMM_PRECISION_F32: { + LIBXSMM_ASSERT(iprec == oprec); + LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "%g", NULL != alpha ? *((const float*)alpha) : LIBXSMM_ALPHA); + LIBXSMM_SNPRINTF(string_b, sizeof(string_b), "%g", NULL != beta ? *((const float*)beta) : LIBXSMM_BETA); + mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_F32; + typeprefix = 's'; + } break; + default: if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { /* TODO: support I16, etc. */ + fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n"); + } + } + } + + if (0 != typeprefix) { + if (NULL != ostream) { /* print information about GEMM call */ + if (NULL != a && NULL != b && NULL != c) { + fprintf((FILE*)ostream, "%cgemm('%c', '%c', %" PRIuPTR "/*m*/, %" PRIuPTR "/*n*/, %" PRIuPTR "/*k*/,\n" + " %s/*alpha*/, %p/*a*/, %" PRIuPTR "/*lda*/,\n" + " %p/*b*/, %" PRIuPTR "/*ldb*/,\n" + " %s/*beta*/, %p/*c*/, %" PRIuPTR "/*ldc*/)", + typeprefix, ctransa, ctransb, (uintptr_t)*m, (uintptr_t)nn, (uintptr_t)kk, + string_a, a, (uintptr_t)ilda, b, (uintptr_t)ildb, string_b, c, (uintptr_t)ildc); + } + else { + fprintf((FILE*)ostream, "%cgemm(trans=%c%c mnk=%" PRIuPTR ",%" PRIuPTR ",%" PRIuPTR + " ldx=%" PRIuPTR ",%" PRIuPTR ",%" PRIuPTR " a,b=%s,%s)", + typeprefix, ctransa, ctransb, (uintptr_t)*m, (uintptr_t)nn, (uintptr_t)kk, + (uintptr_t)ilda, (uintptr_t)ildb, (uintptr_t)ildc, string_a, string_b); + } + } + else { /* dump A, B, and C matrices into MHD files */ + char extension_header[256]; + size_t data_size[2], size[2]; + + if (NULL != a) { + LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "TRANS = %c\nALPHA = %s", ctransa, string_a); + LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_a_%p.mhd", a); + data_size[0] = (size_t)ilda; data_size[1] = (size_t)kk; size[0] = (size_t)(*m); size[1] = (size_t)kk; + libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype, + NULL/*conversion*/, a, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/); + } + if (NULL != b) { + LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "\nTRANS = %c", ctransb); + LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_b_%p.mhd", b); + data_size[0] = (size_t)ildb; data_size[1] = (size_t)nn; size[0] = (size_t)kk; size[1] = (size_t)nn; + libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype, + NULL/*conversion*/, b, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/); + } + if (NULL != c) { + LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "BETA = %s", string_b); + LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_c_%p.mhd", c); + data_size[0] = (size_t)ildc; data_size[1] = (size_t)nn; size[0] = (size_t)(*m); size[1] = (size_t)nn; + libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype, + NULL/*conversion*/, c, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/); + } + } + } +} + + +LIBXSMM_API void libxsmm_gemm_dprint( + void* ostream, libxsmm_gemm_precision precision, char transa, char transb, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, double dalpha, const void* a, libxsmm_blasint lda, + const void* b, libxsmm_blasint ldb, double dbeta, void* c, libxsmm_blasint ldc) +{ + libxsmm_gemm_dprint2(ostream, precision, precision, transa, transb, m, n, k, dalpha, a, lda, b, ldb, dbeta, c, ldc); +} + + +LIBXSMM_API void libxsmm_gemm_dprint2( + void* ostream, libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, char transa, char transb, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, double dalpha, const void* a, libxsmm_blasint lda, + const void* b, libxsmm_blasint ldb, double dbeta, void* c, libxsmm_blasint ldc) +{ + switch (iprec) { + case LIBXSMM_GEMM_PRECISION_F64: { + libxsmm_gemm_print2(ostream, LIBXSMM_GEMM_PRECISION_F64, oprec, &transa, &transb, + &m, &n, &k, &dalpha, a, &lda, b, &ldb, &dbeta, c, &ldc); + } break; + case LIBXSMM_GEMM_PRECISION_F32: { + const float alpha = (float)dalpha, beta = (float)dbeta; + libxsmm_gemm_print2(ostream, LIBXSMM_GEMM_PRECISION_F32, oprec, &transa, &transb, + &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } break; + default: { + libxsmm_gemm_print2(ostream, iprec, oprec, &transa, &transb, + &m, &n, &k, &dalpha, a, &lda, b, &ldb, &dbeta, c, &ldc); + } + } +} + + +LIBXSMM_API void libxsmm_gemm_xprint(void* ostream, + libxsmm_xmmfunction kernel, const void* a, const void* b, void* c) +{ + const libxsmm_descriptor* desc; + libxsmm_code_pointer code; + size_t code_size; + code.xgemm = kernel; + if (NULL != libxsmm_get_kernel_xinfo(code, &desc, &code_size) && + NULL != desc && LIBXSMM_KERNEL_KIND_MATMUL == LIBXSMM_DESCRIPTOR_KIND(desc->kind)) + { + libxsmm_gemm_dprint2(ostream, + (libxsmm_gemm_precision)LIBXSMM_GETENUM_INP(desc->gemm.desc.datatype), + (libxsmm_gemm_precision)LIBXSMM_GETENUM_OUT(desc->gemm.desc.datatype), + (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & desc->gemm.desc.flags) ? 'N' : 'T'), + (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & desc->gemm.desc.flags) ? 'N' : 'T'), + (libxsmm_blasint)desc->gemm.desc.m, (libxsmm_blasint)desc->gemm.desc.n, (libxsmm_blasint)desc->gemm.desc.k, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, a, + (libxsmm_blasint)desc->gemm.desc.lda, b, (libxsmm_blasint)desc->gemm.desc.ldb, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, c, (libxsmm_blasint)desc->gemm.desc.ldc); + fprintf((FILE*)ostream, " = %p+%u", code.ptr_const, (unsigned int)code_size); + } +} + + +LIBXSMM_API void libxsmm_blas_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_INIT + switch (iprec) { + case LIBXSMM_GEMM_PRECISION_F64: { + LIBXSMM_ASSERT(iprec == oprec); + LIBXSMM_BLAS_XGEMM(double, double, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } break; + case LIBXSMM_GEMM_PRECISION_F32: { + LIBXSMM_ASSERT(iprec == oprec); + LIBXSMM_BLAS_XGEMM(float, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } break; + default: if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + static int error_once = 0; + LIBXSMM_UNUSED(oprec); + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { /* TODO: support I16, etc. */ + fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n"); + } + } + } +} + + +LIBXSMM_API_INLINE int libxsmm_gemm_plan_internal(unsigned int ntasks, + unsigned int m, unsigned int n, unsigned int k, /* whole problem size */ + unsigned int tm, unsigned int tn, unsigned int tk, /* tile size (kernel) */ + unsigned int* nmt, unsigned int* nnt, unsigned int* nkt, /* number of tiles */ + unsigned int* mt, unsigned int* nt, unsigned int* kt) /* number of tasks */ +{ + unsigned int result = EXIT_SUCCESS, replan = 0; + LIBXSMM_ASSERT(NULL != nmt && NULL != nnt && NULL != nkt); + LIBXSMM_ASSERT(NULL != mt && NULL != nt && NULL != kt); + LIBXSMM_ASSERT(0 < ntasks); + *nmt = (m + tm - 1) / LIBXSMM_MAX(tm, 1); + *nnt = (n + tn - 1) / LIBXSMM_MAX(tn, 1); + *nkt = (k + tk - 1) / LIBXSMM_MAX(tk, 1); +#if !defined(NDEBUG) + *mt = *nt = *kt = 0; +#endif + do { + if (1 >= replan) *mt = libxsmm_product_limit(*nmt, ntasks, 0); + if (1 == replan || ntasks <= *mt) { /* M-parallelism */ + *nt = 1; + *kt = 1; + replan = 0; + } + else { + const unsigned int mntasks = libxsmm_product_limit((*nmt) * (*nnt), ntasks, 0); + if (0 == replan && *mt >= mntasks) replan = 1; + if (2 == replan || (0 == replan && ntasks <= mntasks)) { /* MN-parallelism */ + *nt = mntasks / *mt; + *kt = 1; + replan = 0; + } + else { /* MNK-parallelism */ + const unsigned int mnktasks = libxsmm_product_limit((*nmt) * (*nnt) * (*nkt), ntasks, 0); + if (mntasks < mnktasks) { +#if defined(LIBXSMM_GEMM_KPARALLEL) + *nt = mntasks / *mt; + *kt = mnktasks / mntasks; + replan = 0; +#else + static int error_once = 0; + if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: XGEMM K-parallelism triggered!\n"); + } +#endif + } +#if defined(LIBXSMM_GEMM_KPARALLEL) + else +#endif + if (0 == replan) replan = 2; + } + } + } while (0 != replan); + if (0 == *mt || 0 == *nt || 0 == *kt) { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API libxsmm_gemm_handle* libxsmm_gemm_handle_init(libxsmm_gemm_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const void* alpha, const void* beta, int flags, /*unsigned*/int ntasks) +{ + unsigned int ulda, uldb, um, un, uk, tm = 0, tn = 0, tk = 0, max_ntasks = 0; + libxsmm_descriptor_blob desc_blob; + union { + libxsmm_gemm_handle* ptr; + libxsmm_gemm_blob* blob; + } result; + LIBXSMM_ASSERT(sizeof(libxsmm_gemm_handle) <= sizeof(libxsmm_gemm_blob)); + if (NULL != blob && NULL != m && 0 < ntasks) { + unsigned int ntm = 0, ntn = 0, ntk = 0, mt = 1, nt = 1, kt = 1; + const char *const env_tm = getenv("LIBXSMM_TGEMM_M"); + libxsmm_blasint klda, kldb, kldc, km, kn; + libxsmm_gemm_descriptor* desc; + double dbeta; + LIBXSMM_INIT + result.blob = blob; +#if defined(NDEBUG) + result.ptr->copy_a.ptr = result.ptr->copy_b.ptr = result.ptr->copy_i.ptr = result.ptr->copy_o.ptr = NULL; +#else + memset(blob, 0, sizeof(libxsmm_gemm_blob)); +#endif + if (EXIT_SUCCESS != libxsmm_dvalue((libxsmm_datatype)oprec, beta, &dbeta)) dbeta = LIBXSMM_BETA; /* fuse beta into flags */ + result.ptr->gemm_flags = LIBXSMM_GEMM_PFLAGS(transa, transb, LIBXSMM_FLAGS) | (LIBXSMM_NEQ(0, dbeta) ? 0 : LIBXSMM_GEMM_FLAG_BETA_0); + /* TODO: check that arguments fit into handle (unsigned int vs. libxsmm_blasint) */ + um = (unsigned int)(*m); uk = (NULL != k ? ((unsigned int)(*k)) : um); un = (NULL != n ? ((unsigned int)(*n)) : uk); + result.ptr->otypesize = libxsmm_typesize((libxsmm_datatype)oprec); + if (NULL == env_tm || 0 >= atoi(env_tm)) { + const unsigned int vwidth = LIBXSMM_MAX(internal_gemm_vwidth / result.ptr->otypesize, 1); + const double s2 = (double)internal_gemm_nstretch * internal_gemm_kstretch; /* LIBXSMM_INIT! */ + unsigned int tmi = libxsmm_product_limit(um, internal_gemm_mlimit, 0); /* LIBXSMM_INIT! */ + for (; vwidth <= tmi; tmi = libxsmm_product_limit(um, tmi - 1, 0)) { + const double si = (double)(LIBXSMM_CONFIG_MAX_MNK) / ((double)tmi * tmi * tmi), s = (s2 <= si ? 1 : (s2 / si)); + unsigned int tni = libxsmm_product_limit(un, LIBXSMM_MAX((unsigned int)(tmi * (s * internal_gemm_nstretch)), 1), 0); + unsigned int tki = libxsmm_product_limit(uk, LIBXSMM_MAX((unsigned int)(tmi * (s * internal_gemm_kstretch)), 1), 0); + unsigned int ntmi, ntni, ntki, mti = 1, nti = 1, kti = 1; + LIBXSMM_ASSERT(tmi <= um && tni <= un && tki <= uk); + if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) { + const unsigned int ttm = (unsigned int)libxsmm_product_limit(tmi, (unsigned int)ntasks, 0); + const unsigned int ttn = (unsigned int)libxsmm_product_limit(tni, (unsigned int)ntasks, 0); + tmi = tni = LIBXSMM_MIN(ttm, ttn); /* prefer threads over larger tile */ + } + if (EXIT_SUCCESS == libxsmm_gemm_plan_internal((unsigned int)ntasks, um, un, uk, tmi, tni, tki, + &ntmi, &ntni, &ntki, &mti, &nti, &kti)) + { + const int exit_plan = ((tmi < um && tni < un && tki < uk && (tm != tmi || tn != tni || tk != tki)) ? 0 : 1); + const unsigned itasks = mti * nti * kti; + LIBXSMM_ASSERT(1 <= itasks); + if (max_ntasks < itasks) { + ntm = ntmi; ntn = ntni; ntk = ntki; + mt = mti; nt = nti; kt = kti; + tm = tmi; tn = tni; tk = tki; + max_ntasks = itasks; + } + if (itasks == (unsigned int)ntasks || 0 != exit_plan) break; + } + } + } + else { + const unsigned int tmi = atoi(env_tm); + const double s2 = (double)internal_gemm_nstretch * internal_gemm_kstretch; /* LIBXSMM_INIT! */ + double si, s; + tm = libxsmm_product_limit(um, LIBXSMM_MIN(tmi, internal_gemm_mlimit), 0); /* LIBXSMM_INIT! */ + si = (double)(LIBXSMM_CONFIG_MAX_MNK) / ((double)tm * tm * tm); s = (s2 <= si ? 1 : (s2 / si)); + tn = libxsmm_product_limit(un, LIBXSMM_MAX((unsigned int)(tm * (s * internal_gemm_nstretch)), 1), 0); + tk = libxsmm_product_limit(uk, LIBXSMM_MAX((unsigned int)(tm * (s * internal_gemm_kstretch)), 1), 0); + if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) { + const unsigned int ttm = (unsigned int)libxsmm_product_limit(tm, (unsigned int)ntasks, 0); + const unsigned int ttn = (unsigned int)libxsmm_product_limit(tn, (unsigned int)ntasks, 0); + tm = tn = LIBXSMM_MIN(ttm, ttn); /* prefer threads over larger tile */ + } + if (EXIT_SUCCESS == libxsmm_gemm_plan_internal((unsigned int)ntasks, um, un, uk, tm, tn, tk, + &ntm, &ntn, &ntk, &mt, &nt, &kt)) + { +#if defined(NDEBUG) + max_ntasks = 2; /* only need something unequal to zero to pass below condition */ +#else + max_ntasks = mt * nt * kt; +#endif + } + } + LIBXSMM_ASSERT(LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags) || tm == tn); + /* check for non-conforming GEMM parameters (error), and conforming GEMM parameters (fast-path, fallback) */ + if (0 == max_ntasks || 0 == tm || 0 == tn || 0 == tk || 0 != (um % tm) || 0 != (un % tn) || 0 != (uk % tk)) { + return NULL; + } + result.ptr->flags = flags; + if (LIBXSMM_GEMM_HANDLE_FLAG_AUTO == flags && 0 == LIBXSMM_SMM_AI(um, un, uk, + 0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0) ? 1 : 2/*RFO*/, result.ptr->otypesize)) + { + if (um == LIBXSMM_UP2POT(um) || un == LIBXSMM_UP2POT(un)) { /* power-of-two (POT) extent(s) */ + result.ptr->flags |= LIBXSMM_GEMM_HANDLE_FLAG_COPY_C; + if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) { + result.ptr->flags |= LIBXSMM_GEMM_HANDLE_FLAG_COPY_A; + } + } + } + result.ptr->itypesize = libxsmm_typesize((libxsmm_datatype)iprec); + result.ptr->ldc = (unsigned int)(NULL != ldc ? *ldc : *m); + ulda = (NULL != lda ? ((unsigned int)(*lda)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & result.ptr->gemm_flags) ? ((unsigned int)(*m)) : uk)); + uldb = (NULL != ldb ? ((unsigned int)(*ldb)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & result.ptr->gemm_flags) ? uk : un)); + if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) { /* NN, NT, or TN */ + const libxsmm_blasint itm = (libxsmm_blasint)tm, itk = (libxsmm_blasint)tk; +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + const libxsmm_blasint itn = (libxsmm_blasint)tn; +#endif + kldc = (libxsmm_blasint)result.ptr->ldc; + klda = (libxsmm_blasint)ulda; + kldb = (libxsmm_blasint)uldb; + if (0 != (LIBXSMM_GEMM_FLAG_TRANS_A & result.ptr->gemm_flags)) { /* TN */ +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + result.ptr->copy_a.function = libxsmm_dispatch_meltw_unary(itk, itm, &klda, &itm, + (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); +#endif + klda = itm; + } + else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & result.ptr->flags)) { +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + result.ptr->copy_a.function = libxsmm_dispatch_meltw_unary(itm, itk, &klda, &itm, + (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); +#endif + klda = (libxsmm_blasint)tm; + } + if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & result.ptr->gemm_flags)) { /* NT */ +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + result.ptr->copy_b.function = libxsmm_dispatch_meltw_unary(itn, itk, &kldb, &itk, + (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); +#endif + kldb = itk; + } + else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & result.ptr->flags)) { +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + result.ptr->copy_b.function = libxsmm_dispatch_meltw_unary(itk, itn, &kldb, &itk, + (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); +#endif + kldb = (libxsmm_blasint)tk; + } + if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & result.ptr->flags)) { +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + result.ptr->copy_o.function = libxsmm_dispatch_meltw_unary(itm, itn, &itm, &kldc, + (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); + if (0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */ + result.ptr->copy_i.function = libxsmm_dispatch_meltw_unary(itm, itn, &kldc, &itm, + (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); + } +#endif + kldc = (libxsmm_blasint)tm; + } + result.ptr->lda = ulda; result.ptr->ldb = uldb; + result.ptr->km = tm; result.ptr->kn = tn; + result.ptr->mt = mt; result.ptr->nt = nt; + result.ptr->m = um; result.ptr->n = un; + result.ptr->dm = LIBXSMM_UPDIV(ntm, mt) * tm; + result.ptr->dn = LIBXSMM_UPDIV(ntn, nt) * tn; + km = tm; kn = tn; + } + else { /* TT */ + const unsigned int tt = tm; + const libxsmm_blasint itt = (libxsmm_blasint)tt; +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + const libxsmm_blasint ildc = (libxsmm_blasint)result.ptr->ldc; + result.ptr->copy_o.function = libxsmm_dispatch_meltw_unary(itt, itt, &itt, &ildc, + (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + if (0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */ + result.ptr->copy_i.function = libxsmm_dispatch_meltw_unary(itt, itt, &ildc, &itt, + (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, (libxsmm_datatype)oprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + } +#endif + klda = (libxsmm_blasint)uldb; + kldb = (libxsmm_blasint)ulda; + kldc = itt; + LIBXSMM_ASSERT(tt == tn); + if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & result.ptr->flags)) { +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + const libxsmm_blasint itk = (libxsmm_blasint)tk; + result.ptr->copy_a.function = libxsmm_dispatch_meltw_unary(itt, itk, &kldb, &itk, + (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); +#endif + klda = itt; + } + if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & result.ptr->flags)) { +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + const libxsmm_blasint itn = (libxsmm_blasint)tn, itk = (libxsmm_blasint)tk; + result.ptr->copy_b.function = libxsmm_dispatch_meltw_unary(itk, itn, &klda, &itk, + (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, (libxsmm_datatype)iprec, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY); +#endif + kldb = (libxsmm_blasint)tk; + } + result.ptr->lda = uldb; result.ptr->ldb = ulda; + result.ptr->km = tn; result.ptr->kn = tm; + result.ptr->mt = nt; result.ptr->nt = mt; + result.ptr->m = un; result.ptr->n = um; + result.ptr->dm = LIBXSMM_UPDIV(ntn, nt) * tn; + result.ptr->dn = LIBXSMM_UPDIV(ntm, mt) * tm; + km = kn = tt; + } + result.ptr->dk = ntk / kt * tk; + result.ptr->kk = tk; + result.ptr->kt = kt; + result.ptr->k = uk; + desc = libxsmm_gemm_descriptor_init2( /* remove transpose flags from kernel request */ + &desc_blob, iprec, oprec, km, kn, result.ptr->kk, klda, kldb, kldc, + alpha, beta, result.ptr->gemm_flags & ~LIBXSMM_GEMM_FLAG_TRANS_AB, internal_gemm_tiled_prefetch); + result.ptr->kernel[0] = libxsmm_xmmdispatch(desc); + if (NULL != result.ptr->kernel[0].xmm) { + if (0 == (desc->flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* beta!=0 */ + result.ptr->kernel[1] = result.ptr->kernel[0]; + } + else { /* generate kernel with beta=1 */ + desc->flags &= ~LIBXSMM_GEMM_FLAG_BETA_0; + result.ptr->kernel[1] = libxsmm_xmmdispatch(desc); + if (NULL == result.ptr->kernel[1].xmm) result.ptr = NULL; + } + } + else result.ptr = NULL; + } + else { + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_a(const libxsmm_gemm_handle* handle) +{ + size_t result; + if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_A) + && (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) || + (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags) == 0))) + { + result = 0; + } + else { + const size_t size = (size_t)handle->km * handle->kk * handle->itypesize; + result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE); + } + return result; +} + + +LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_b(const libxsmm_gemm_handle* handle) +{ + size_t result; + if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_B) + && (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) || + (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags) == 0))) + { + result = 0; + } + else { + const size_t size = (size_t)handle->kk * handle->kn * handle->itypesize; + result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE); + } + return result; +} + + +LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_c(const libxsmm_gemm_handle* handle) +{ + size_t result; + if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_C) + && LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags))) + { + result = 0; + } + else { + const size_t size = (size_t)handle->km * handle->kn * handle->otypesize; + result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE); + } + return result; +} + + +LIBXSMM_API size_t libxsmm_gemm_handle_get_scratch_size(const libxsmm_gemm_handle* handle) +{ + size_t result; + if (NULL != handle) { /* thread-local scratch buffer for GEMM */ + const size_t size_a = libxsmm_gemm_handle_get_scratch_size_a(handle); + const size_t size_b = libxsmm_gemm_handle_get_scratch_size_b(handle); + const size_t size_c = libxsmm_gemm_handle_get_scratch_size_c(handle); + result = (size_a + size_b + size_c) * handle->mt * handle->nt * handle->kt; + } + else { + result = 0; + } + return result; +} + + +LIBXSMM_API void libxsmm_gemm_task(const libxsmm_gemm_handle* handle, void* scratch, + const void* a, const void* b, void* c, /*unsigned*/int tid, /*unsigned*/int ntasks) +{ +#if !defined(NDEBUG) + if (NULL != handle && 0 <= tid && tid < ntasks) +#endif + { + const unsigned int utasks = (unsigned int)ntasks; + const unsigned int wksize = handle->mt * handle->nt * handle->kt; + const unsigned int spread = (wksize <= utasks ? (utasks / wksize) : 1); + const unsigned int utid = (unsigned int)tid, vtid = utid / spread; + if (utid < (spread * wksize) && 0 == (utid - vtid * spread)) { + const int excess = (utasks << 1) <= (vtid + wksize); + const unsigned int rtid = vtid / handle->mt, mtid = vtid - rtid * handle->mt, ntid = rtid % handle->nt, ktid = vtid / (handle->mt * handle->nt); + const unsigned int m0 = mtid * handle->dm, m1 = (0 == excess ? LIBXSMM_MIN(m0 + handle->dm, handle->m) : handle->m); + const unsigned int n0 = ntid * handle->dn, n1 = (0 == excess ? LIBXSMM_MIN(n0 + handle->dn, handle->n) : handle->n); + const unsigned int k0 = ktid * handle->dk, k1 = (0 == excess ? LIBXSMM_MIN(k0 + handle->dk, handle->k) : handle->k); + const unsigned int ldo = (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ? handle->km : handle->kk); + /* calculate increments to simplify address calculations */ + const unsigned int dom = handle->km * handle->otypesize; + const unsigned int don = handle->kn * handle->otypesize; + const unsigned int dik = handle->kk * handle->itypesize; + const unsigned int on = handle->otypesize * n0; + /* calculate base address of thread-local storage */ + const size_t size_a = libxsmm_gemm_handle_get_scratch_size_a(handle); + const size_t size_b = libxsmm_gemm_handle_get_scratch_size_b(handle); + const size_t size_c = libxsmm_gemm_handle_get_scratch_size_c(handle); + char *const at = (char*)scratch + (size_a + size_b + size_c) * vtid; + char *const bt = at + size_a, *const ct = bt + size_b; + const libxsmm_xcopykernel kernel = { NULL }; + /* loop induction variables and other variables */ + unsigned int om = handle->otypesize * m0, im = m0, in = n0, ik = k0, im1, in1, ik1; + LIBXSMM_ASSERT_MSG(mtid < handle->mt && ntid < handle->nt && ktid < handle->kt, "Invalid task-ID"); + LIBXSMM_ASSERT_MSG(m1 <= handle->m && n1 <= handle->n && k1 <= handle->k, "Invalid task size"); + for (im1 = im + handle->km; (im1 - 1) < m1; im = im1, im1 += handle->km, om += dom) { + unsigned int dn = don, dka = dik, dkb = dik; + char *c0 = (char*)c, *ci; + const char *aa; + if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) { + if (0 != (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags)) { /* TN */ + aa = (const char*)a + ((size_t)im * handle->lda + k0) * handle->itypesize; + } + else if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags)) { /* NT */ + aa = (const char*)a + ((size_t)k0 * handle->lda + im) * handle->itypesize; + dka *= handle->lda; dkb *= handle->ldb; + } + else { /* NN */ + aa = (const char*)a + ((size_t)k0 * handle->lda + im) * handle->itypesize; + dka *= handle->lda; + } + c0 += (size_t)on * handle->ldc + om; + dn *= handle->ldc; + } + else { /* TT */ + aa = (const char*)b + ((size_t)k0 * handle->lda + im) * handle->itypesize; + c0 += (size_t)on + handle->ldc * (size_t)om; + dka *= handle->lda; + } + for (in = n0, in1 = in + handle->kn; (in1 - 1) < n1; in = in1, in1 += handle->kn, c0 += dn) { + const char *a0 = aa, *b0 = (const char*)b; + if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) { + if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags)) { /* NT */ + b0 += ((size_t)k0 * handle->ldb + in) * handle->itypesize; + } + else { /* NN or TN */ + b0 += ((size_t)in * handle->ldb + k0) * handle->itypesize; + } + } + else { /* TT */ + b0 = (const char*)a + ((size_t)in * handle->ldb + k0) * handle->itypesize; + } +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + if (NULL == handle->copy_i.ptr) +#endif + { + ci = (NULL == handle->copy_o.ptr ? c0 : ct); + if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) { + const unsigned int km = handle->kn, kn = handle->km; + libxsmm_otrans_internal(ct/*out*/, c0/*in*/, handle->otypesize, handle->ldc/*ldi*/, kn/*ldo*/, + 0, km, 0, kn, km/*tile*/, kn/*tile*/, kernel); + ci = ct; + } + else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & handle->flags)) { + if (0 == (handle->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */ + libxsmm_matcopy_internal(ct/*out*/, c0/*in*/, handle->otypesize, handle->ldc/*ldi*/, handle->km/*ldo*/, + 0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel); + } + ci = ct; + } + } +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + else { /* MCOPY/TCOPY kernel */ + LIBXSMM_MCOPY_CALL(handle->copy_i, handle->otypesize, c0, &handle->ldc, ct, &handle->km); + ci = ct; + } +#endif + for (ik = k0, ik1 = ik + handle->kk; (ik1 - 1) < k1; ik = ik1, ik1 += handle->kk) { + const char *const a1 = a0 + dka, *const b1 = b0 + dkb, *ai = a0, *bi = b0; +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + if (NULL == handle->copy_a.ptr) +#endif + { + if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) && + (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags) != 0) /* pure A-transpose */ + { + LIBXSMM_ASSERT(ldo == handle->km); + libxsmm_otrans_internal(at/*out*/, a0/*in*/, handle->itypesize, handle->lda/*ldi*/, ldo, + 0, handle->kk, 0, handle->km, handle->kk/*tile*/, handle->km/*tile*/, kernel); + ai = at; + } + else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & handle->flags)) { + libxsmm_matcopy_internal(at/*out*/, a0/*in*/, handle->itypesize, handle->lda/*ldi*/, ldo, + 0, handle->km, 0, handle->kk, handle->km/*tile*/, handle->kk/*tile*/, kernel); + ai = at; + } + } +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + else { /* MCOPY/TCOPY kernel */ + LIBXSMM_MCOPY_CALL(handle->copy_a, handle->itypesize, a0, &handle->lda, at, &ldo); + ai = at; + } +#endif +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + if (NULL == handle->copy_b.ptr) +#endif + { + if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) && + (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags) != 0) /* pure B-transpose */ + { + libxsmm_otrans_internal(bt/*out*/, b0/*in*/, handle->itypesize, handle->ldb/*ldi*/, handle->kk/*ldo*/, + 0, handle->kn, 0, handle->kk, handle->kn/*tile*/, handle->kk/*tile*/, kernel); + bi = bt; + } + else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & handle->flags)) { + libxsmm_matcopy_internal(bt/*out*/, b0/*in*/, handle->itypesize, handle->ldb/*ldi*/, handle->kk/*ldo*/, + 0, handle->kk, 0, handle->kn, handle->kk/*tile*/, handle->kn/*tile*/, kernel); + bi = bt; + } + } +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + else { /* MCOPY/TCOPY kernel */ + LIBXSMM_MCOPY_CALL(handle->copy_b, handle->itypesize, b0, &handle->ldb, bt, &handle->kk); + bi = bt; + } +#endif + /* beta0-kernel on first-touch, beta1-kernel otherwise (beta0/beta1 are identical if beta=1) */ + LIBXSMM_MMCALL_PRF(handle->kernel[k0!=ik?1:0].xmm, ai, bi, ci, a1, b1, c0); + a0 = a1; + b0 = b1; + } + /* TODO: synchronize */ +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + if (NULL == handle->copy_o.ptr) +#endif + { + if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) { + libxsmm_otrans_internal(c0/*out*/, ct/*in*/, handle->otypesize, handle->km/*ldi*/, handle->ldc/*ldo*/, + 0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel); + } + else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & handle->flags)) { + libxsmm_matcopy_internal(c0/*out*/, ct/*in*/, handle->otypesize, handle->km/*ldi*/, handle->ldc/*ldo*/, + 0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel); + } + } +#if defined(LIBXSMM_GEMM_XCOPY_JIT) + else { /* MCOPY/TCOPY kernel */ + LIBXSMM_MCOPY_CALL(handle->copy_o, handle->otypesize, ct, &handle->km, c0, &handle->ldc); + } +#endif + } + } + } + } +#if !defined(NDEBUG) + else if (/*implies LIBXSMM_INIT*/0 != libxsmm_get_verbosity()) { /* library code is expected to be mute */ + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM ERROR: libxsmm_gemm_task - invalid handle!\n"); + } + } +#endif +} + + +LIBXSMM_API void libxsmm_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc) +{ + libxsmm_gemm_blob blob; + const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob, iprec, oprec, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta, LIBXSMM_GEMM_HANDLE_FLAG_AUTO, 1/*ntasks*/); + const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle); + void* scratch = NULL; + if (NULL != handle && (0 == scratch_size || + NULL != (scratch = libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER)))) + { + libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*ntasks*/); + libxsmm_free(scratch); + } + else { /* fallback or error */ + static int error_once = 0; + if (NULL == handle) { /* fallback */ + if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: XGEMM fallback code path triggered!\n"); + } + } + else if (0 != libxsmm_verbosity && /* library code is expected to be mute */ + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate GEMM-scratch memory!\n"); + } + libxsmm_blas_xgemm(iprec, oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +} + + +LIBXSMM_API void libxsmm_dgemm_batch( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[], + const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + const libxsmm_blasint ngroups = LIBXSMM_ABS(*group_count), ptrsize = sizeof(void*); + libxsmm_blasint i, j = 0; + for (i = 0; i < ngroups; ++i) { + const libxsmm_blasint size = group_size[i]; + libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, transa_array + i, transb_array + i, + m_array[i], n_array[i], k_array[i], alpha_array + i, a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, c_array + j, ldc_array + i, + 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size); + j += LIBXSMM_ABS(size); + } +} + + +LIBXSMM_API void libxsmm_sgemm_batch( + const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[], + const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[], + const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) +{ + const libxsmm_blasint ngroups = LIBXSMM_ABS(*group_count), ptrsize = sizeof(void*); + libxsmm_blasint i, j = 0; + for (i = 0; i < ngroups; ++i) { + const libxsmm_blasint size = group_size[i]; + libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, transa_array + i, transb_array + i, + m_array[i], n_array[i], k_array[i], alpha_array + i, a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, c_array + j, ldc_array + i, + 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size); + j += LIBXSMM_ABS(size); + } +} + + +LIBXSMM_API void libxsmm_dgemm(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_XGEMM(double, double, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_API void libxsmm_sgemm(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_XGEMM(float, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_API void libxsmm_wigemm(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const int* alpha, const short* a, const libxsmm_blasint* lda, + const short* b, const libxsmm_blasint* ldb, + const int* beta, int* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_XGEMM(short, int, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_API void libxsmm_bsgemm(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const libxsmm_bfloat16* a, const libxsmm_blasint* lda, + const libxsmm_bfloat16* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_XGEMM(libxsmm_bfloat16, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +LIBXSMM_API int libxsmm_mmbatch_kernel(libxsmm_xmmfunction kernel, libxsmm_blasint index_base, + libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + const void* a, const void* b, void* c, libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int ntasks, + unsigned char itypesize, unsigned char otypesize, int flags) +{ + int result = EXIT_SUCCESS; + const libxsmm_blasint size = LIBXSMM_ABS(batchsize); + const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks); + const libxsmm_blasint begin = tid * tasksize, span = begin + tasksize; + const libxsmm_blasint end = LIBXSMM_MIN(span, size); + + LIBXSMM_ASSERT(NULL != a && NULL != b && NULL != c && NULL != kernel.xmm); + if (begin < end) { + const char *const a0 = (const char*)a, *const b0 = (const char*)b; + char *const c0 = (char*)c; + + LIBXSMM_ASSERT(0 < itypesize && 0 < otypesize); + if (0 == (LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS & flags)) { + if (0 != index_stride) { /* stride arrays contain indexes */ + libxsmm_blasint i = begin * index_stride, ic = (NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0); + const char* ai = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0]; + const char* bi = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0]; + char* ci = &c0[ic * otypesize]; + const libxsmm_blasint end1 = (end != size ? end : (end - 1)) * index_stride; +#if (0 != LIBXSMM_SYNC) + if (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize || 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & flags)) +#endif + { /* no locking */ + if (NULL != stride_a && NULL != stride_b && NULL != stride_c) { + const unsigned char ibits = (unsigned char)LIBXSMM_INTRINSICS_BITSCANBWD32(itypesize); + const unsigned char obits = (unsigned char)LIBXSMM_INTRINSICS_BITSCANBWD32(otypesize); + + if (itypesize == (1 << ibits) && otypesize == (1 << obits)) { + for (i += index_stride; i <= end1; i += index_stride) { + const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) << ibits]; + const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) << ibits]; + char *const cn = &c0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) << obits]; + kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */ + ai = an; bi = bn; ci = cn; /* next */ + } + } + else { /* non-pot type sizes */ + for (i += index_stride; i <= end1; i += index_stride) { + const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize]; + const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize]; + char *const cn = &c0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize]; + kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */ + ai = an; bi = bn; ci = cn; /* next */ + } + } + } + else { /* mixed specification of strides */ + for (i += index_stride; i <= end1; i += index_stride) { + const char *const an = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0]; + const char *const bn = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0]; + char *const cn = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize) : 0]; + kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */ + ai = an; bi = bn; ci = cn; /* next */ + } + } + if (end == size) { /* remainder multiplication */ + kernel.xmm(ai, bi, ci, ai, bi, ci); /* pseudo-prefetch */ + } + } +#if (0 != LIBXSMM_SYNC) + else { /* synchronize among C-indexes */ + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state; +# if defined(LIBXSMM_GEMM_LOCKFWD) + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock0 = NULL; +# endif + LIBXSMM_ASSERT(NULL != lock); + if (NULL != stride_a && NULL != stride_b && NULL != stride_c) { + for (i += index_stride; i <= end1; i += index_stride) { + ic = LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base; + { + const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize]; + const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize]; + char *const cn = &c0[ic * otypesize]; + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state; +# if defined(LIBXSMM_GEMM_LOCKFWD) + if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); } +# else + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); +# endif + kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */ +# if defined(LIBXSMM_GEMM_LOCKFWD) + if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; } +# else + LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; +# endif + ai = an; bi = bn; ci = cn; /* next */ + } + } + } + else { + for (i += index_stride; i <= end1; i += index_stride) { + ic = (NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0); + { + const char *const an = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0]; + const char *const bn = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0]; + char *const cn = &c0[ic * otypesize]; + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state; +# if defined(LIBXSMM_GEMM_LOCKFWD) + if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); } +# else + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); +# endif + kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */ +# if defined(LIBXSMM_GEMM_LOCKFWD) + if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; } +# else + LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; +# endif + ai = an; bi = bn; ci = cn; /* next */ + } + } + } + if (end == size) { /* remainder multiplication */ + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); + kernel.xmm(ai, bi, ci, ai, bi, ci); /* pseudo-prefetch */ + LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); + } + } +#endif /*(0 != LIBXSMM_SYNC)*/ + } + else { /* array of pointers to matrices (singular strides are measured in Bytes) */ + const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0); + const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0); + const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0); + const libxsmm_blasint end1 = (end != size ? end : (end - 1)); + const char *ai = a0 + (size_t)da * begin, *bi = b0 + (size_t)db * begin; + char* ci = c0 + (size_t)dc * begin; + libxsmm_blasint i; +#if (0 != LIBXSMM_SYNC) + if (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize || 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & flags)) +#endif + { /* no locking */ + for (i = begin; i < end1; ++i) { + const char *const an = ai + da, *const bn = bi + db; + char *const cn = ci + dc; +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != *((const void**)ci)) +#endif + { + kernel.xmm( /* with prefetch */ + *((const void**)ai), *((const void**)bi), *((void**)ci), + *((const void**)an), *((const void**)bn), *((const void**)cn)); + } + ai = an; bi = bn; ci = cn; /* next */ + } + if ( /* remainder multiplication */ +#if defined(LIBXSMM_BATCH_CHECK) + NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != *((const void**)ci) && +#endif + end == size) + { + kernel.xmm( /* pseudo-prefetch */ + *((const void**)ai), *((const void**)bi), *((void**)ci), + *((const void**)ai), *((const void**)bi), *((const void**)ci)); + } + } +#if (0 != LIBXSMM_SYNC) + else { /* synchronize among C-indexes */ + void* cc = *((void**)ci); + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock = &internal_gemm_lock[LIBXSMM_GEMM_LOCKPTR(cc, internal_gemm_nlocks)].state; +# if defined(LIBXSMM_GEMM_LOCKFWD) + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock0 = NULL; +# endif + LIBXSMM_ASSERT(NULL != lock); + for (i = begin + 1; i <= end1; ++i) { + const char *const an = ai + da, *const bn = bi + db; + char *const cn = ci + dc; + void *const nc = *((void**)cn); +# if defined(LIBXSMM_BATCH_CHECK) + if (NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != cc) +# endif + { + LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKPTR(nc, internal_gemm_nlocks)].state; +# if defined(LIBXSMM_GEMM_LOCKFWD) + if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); } +# else + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); +# endif + kernel.xmm( /* with prefetch */ + *((const void**)ai), *((const void**)bi), cc, + *((const void**)an), *((const void**)bn), *((const void**)cn)); +# if defined(LIBXSMM_GEMM_LOCKFWD) + if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; } +# else + LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; +# endif + } + ai = an; bi = bn; ci = cn; cc = nc; /* next */ + } + if ( /* remainder multiplication */ +# if defined(LIBXSMM_BATCH_CHECK) + NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != cc && +# endif + end == size) + { + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); + kernel.xmm( /* pseudo-prefetch */ + *((const void**)ai), *((const void**)bi), cc, + *((const void**)ai), *((const void**)bi), cc); + LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); + } + } +#endif /*(0 != LIBXSMM_SYNC)*/ + } + } +#if defined(LIBXSMM_GEMM_BATCHREDUCE) + else /* LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS */ +# if defined(LIBXSMM_BATCH_CHECK) + if ( +# if (0 != LIBXSMM_SYNC) + (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize) && +# endif + (0 == (LIBXSMM_GEMM_FLAG_BETA_0 & flags)) && + (0 != internal_gemm_batchreduce)) +# endif + { + const unsigned int n = libxsmm_mmbatch_size * (LIBXSMM_GEMM_BATCHSCALE) / ((unsigned int)sizeof(void*)); + LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array && 0 != libxsmm_mmbatch_size); + if ((2U/*A and B matrices*/ * tasksize) <= n) { + const void **ai = (const void**)libxsmm_mmbatch_array + begin, **bi = ai + size; + unsigned long long count; + if (0 != index_stride) { /* stride arrays contain indexes */ + const size_t end_stride = (size_t)end * index_stride; + size_t i = (size_t)begin * index_stride; + char *ci = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize) : 0], *cn = ci; + do { + for (count = 0; i < end_stride && ci == cn; ++count) { + const size_t j = i + index_stride; + *ai++ = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0]; + *bi++ = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0]; + cn = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, j) - index_base) * otypesize) : 0]; + i = j; + } + ai = (const void**)libxsmm_mmbatch_array + begin; bi = ai + size; + kernel.xbm(ai, bi, ci, &count); + ci = cn; + } while (i < end_stride); + } + else { /* array of pointers to matrices (singular strides are measured in Bytes) */ + const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0); + const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0); + const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0); + const char *ia = a0 + (size_t)da * begin, *ib = b0 + (size_t)db * begin; + char* ic = c0 + (size_t)dc * begin; + if ( +# if defined(LIBXSMM_BATCH_CHECK) + NULL != *((const void**)ia) && NULL != *((const void**)ib) && NULL != *((const void**)ic) && +# endif + sizeof(void*) == da && sizeof(void*) == db) /* fast path */ + { + if (0 != dc) { + libxsmm_blasint i = begin; + char* jc = ic; + do { + for (count = 0; i < end && *((const void**)ic) == *((const void**)jc); ++i) { +# if defined(LIBXSMM_BATCH_CHECK) + if (NULL != *((const void**)jc)) +# endif + ++count; + jc += dc; /* next */ + } + memcpy((void*)ai, ia, count * sizeof(void*)); + memcpy((void*)bi, ib, count * sizeof(void*)); + kernel.xbm(ai, bi, *((void**)ic), &count); + ic = jc; + } while (i < end); + } + else { /* fastest path */ + count = (unsigned long long)end - begin; + memcpy((void*)ai, ia, count * sizeof(void*)); + memcpy((void*)bi, ib, count * sizeof(void*)); + kernel.xbm(ai, bi, *((void**)ic), &count); + } + } + else { /* custom-copy required */ + libxsmm_blasint i = begin; + char* jc = ic; + do { + for (count = 0; i < end && *((const void**)ic) == *((const void**)jc); ++i) { +# if defined(LIBXSMM_BATCH_CHECK) + if (NULL != *((const void**)ia) && NULL != *((const void**)ib) && NULL != *((const void**)jc)) +# endif + { + *ai++ = *((const void**)ia); *bi++ = *((const void**)ib); + ++count; + } + ia += da; ib += db; jc += dc; /* next */ + } + ai = (const void**)libxsmm_mmbatch_array + begin; bi = ai + size; + kernel.xbm(ai, bi, *((void**)ic), &count); + ic = jc; + } while (i < end); + } + } + } + else { /* fallback */ + result = EXIT_FAILURE; + } + } +#endif /*defined(LIBXSMM_GEMM_BATCHREDUCE)*/ + } + /* coverity[missing_unlock] */ + return result; +} + + +LIBXSMM_API void libxsmm_gemm_internal_set_batchflag(libxsmm_gemm_descriptor* descriptor, void* c, libxsmm_blasint index_stride, + libxsmm_blasint batchsize, int multithreaded) +{ + LIBXSMM_ASSERT(NULL != descriptor); + if (0 != (LIBXSMM_GEMM_FLAG_BETA_0 & descriptor->flags)) { + const uintptr_t vw = (LIBXSMM_X86_AVX512 <= libxsmm_target_archid ? 64 : 32); + /* assume that all C-matrices are aligned eventually */ + if (0 == LIBXSMM_MOD2((uintptr_t)c, vw) +#if 0 /* should fallback in BE */ + && LIBXSMM_X86_AVX <= libxsmm_target_archid +#endif + && 0 != index_stride) + { + const int oprec = LIBXSMM_GETENUM_OUT(descriptor->datatype); + const libxsmm_blasint typesize = LIBXSMM_TYPESIZE(oprec); + const libxsmm_blasint csize = (libxsmm_blasint)descriptor->ldc * descriptor->n * typesize; + /* finalize assumption if matrix-size is a multiple of the vector-width */ + descriptor->flags |= (unsigned short)(0 == LIBXSMM_MOD2(csize, vw) ? LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT : 0); + } + } +#if defined(LIBXSMM_GEMM_BATCHREDUCE) + else if (0 != internal_gemm_batchreduce) { /* check if reduce-batch kernel can be used */ + static int error_once = 0; + LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array); +# if (0 != LIBXSMM_SYNC) + if (0 == multithreaded || 0 == internal_gemm_nlocks || 0 > batchsize) +# endif + { + int result = EXIT_FAILURE; + switch (LIBXSMM_GETENUM_INP(descriptor->datatype)) { + case LIBXSMM_GEMM_PRECISION_F64: { + if (LIBXSMM_GEMM_PRECISION_F64 == LIBXSMM_GETENUM_OUT(descriptor->datatype)) { + result = EXIT_SUCCESS; + } + } break; + case LIBXSMM_GEMM_PRECISION_F32: { + if (LIBXSMM_GEMM_PRECISION_F32 == LIBXSMM_GETENUM_OUT(descriptor->datatype)) { + result = EXIT_SUCCESS; + } + } break; + } + if (EXIT_SUCCESS == result) { + descriptor->flags |= LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS; + descriptor->prefetch = 0; /* omit decision */ + } + else { + if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) && /* library code is expected to be mute */ + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: data type not supported in batch-reduce!\n"); + } + } + } +# if (0 != LIBXSMM_SYNC) + else if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) && /* library code is expected to be mute */ + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM: potential data races prevent batch-reduce.\n"); + } +# endif + } +#endif /*defined(LIBXSMM_GEMM_BATCHREDUCE)*/ +#if !defined(LIBXSMM_GEMM_BATCHREDUCE) || (0 == LIBXSMM_SYNC) + LIBXSMM_UNUSED(batchsize); LIBXSMM_UNUSED(multithreaded); +#endif +} + + +LIBXSMM_API_INTERN void libxsmm_dmmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const double* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const double* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize) +{ +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != a && NULL != b && NULL != c) +#endif + { + const libxsmm_blasint end = LIBXSMM_ABS(batchsize); + libxsmm_blasint i; + if (0 != index_stride) { /* stride arrays contain indexes */ + const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base) : 0); + const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base) : 0); + const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base) : 0); + const libxsmm_blasint end1 = end * index_stride; + const double *const a0 = (const double*)a, *const b0 = (const double*)b, *ai = a0 + da, *bi = b0 + db; + double *const c0 = (double*)c, *ci = c0 + dc; + for (i = index_stride; i <= end1; i += index_stride) { + const double *const an = &a0[NULL != stride_a ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) : 0]; + const double *const bn = &b0[NULL != stride_b ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) : 0]; + double *const cn = &c0[NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0]; + libxsmm_blas_dgemm(transa, transb, &m, &n, &k, alpha, ai, lda, bi, ldb, beta, ci, ldc); + ai = an; bi = bn; ci = cn; /* next */ + } + } + else { /* array of pointers to matrices (singular strides are measured in Bytes) */ + const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0); + const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0); + const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0); + const char *const a0 = (const char*)a, *const b0 = (const char*)b, *ai = a0, *bi = b0; + char *const c0 = (char*)c, *ci = c0; + for (i = 0; i < end; ++i) { + const char *const an = ai + da, *const bn = bi + db; + char *const cn = ci + dc; +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != *((const double**)ai) && NULL != *((const double**)bi) && NULL != *((const double**)ci)) +#endif + { + libxsmm_blas_dgemm(transa, transb, &m, &n, &k, alpha, *((const double**)ai), lda, *((const double**)bi), ldb, beta, *((double**)ci), ldc); + } + ai = an; bi = bn; ci = cn; /* next */ + } + } + } +} + + +LIBXSMM_API_INTERN void libxsmm_smmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const float* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const float* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize) +{ +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != a && NULL != b && NULL != c) +#endif + { + const libxsmm_blasint end = LIBXSMM_ABS(batchsize); + libxsmm_blasint i; + if (0 != index_stride) { /* stride arrays contain indexes */ + const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base) : 0); + const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base) : 0); + const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base) : 0); + const libxsmm_blasint end1 = end * index_stride; + const float *a0 = (const float*)a, *b0 = (const float*)b, *ai = a0 + da, *bi = b0 + db; + float *c0 = (float*)c, *ci = c0 + dc; + for (i = index_stride; i <= end1; i += index_stride) { + const float *const an = &a0[NULL != stride_a ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) : 0]; + const float *const bn = &b0[NULL != stride_b ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) : 0]; + float *const cn = &c0[NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0]; + libxsmm_blas_sgemm(transa, transb, &m, &n, &k, alpha, ai, lda, bi, ldb, beta, ci, ldc); + ai = an; bi = bn; ci = cn; /* next */ + } + } + else { /* array of pointers to matrices (singular strides are measured in Bytes) */ + const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0); + const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0); + const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0); + const char *a0 = (const char*)a, *b0 = (const char*)b, *ai = a0, *bi = b0; + char *c0 = (char*)c, *ci = c0; + for (i = 0; i < end; ++i) { + const char *const an = ai + da; + const char *const bn = bi + db; + char *const cn = ci + dc; +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != *((const float**)ai) && NULL != *((const float**)bi) && NULL != *((const float**)ci)) +#endif + { + libxsmm_blas_sgemm(transa, transb, &m, &n, &k, alpha, *((const float**)ai), lda, *((const float**)bi), ldb, beta, *((float**)ci), ldc); + } + ai = an; bi = bn; ci = cn; /* next */ + } + } + } +} + + +LIBXSMM_API int libxsmm_mmbatch_blas( + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const void* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize) +{ + int result; + if (NULL != a && NULL != b && NULL != c) { + switch (LIBXSMM_GETENUM(iprec, oprec)) { + case LIBXSMM_GEMM_PRECISION_F64: { + libxsmm_dmmbatch_blas(transa, transb, m, n, k, + (const double*)alpha, a, lda, b, ldb, (const double*)beta, c, ldc, + index_base, index_stride, stride_a, stride_b, stride_c, batchsize); + result = EXIT_SUCCESS; + } break; + case LIBXSMM_GEMM_PRECISION_F32: { + libxsmm_smmbatch_blas(transa, transb, m, n, k, + (const float*)alpha, a, lda, b, ldb, (const float*)beta, c, ldc, + index_base, index_stride, stride_a, stride_b, stride_c, batchsize); + result = EXIT_SUCCESS; + } break; + default: result = EXIT_FAILURE; + } + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API void libxsmm_mmbatch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int ntasks) +{ + static int error_once = 0; +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != a && NULL != b && NULL != c && 0 <= tid && tid < ntasks) +#endif + { + const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)oprec); + int result = EXIT_FAILURE; + LIBXSMM_INIT + if (LIBXSMM_SMM_AI(m, n, k, 2/*RFO*/, otypesize)) { /* check if an SMM is suitable */ + const int gemm_flags = LIBXSMM_GEMM_PFLAGS(transa, transb, LIBXSMM_FLAGS); + libxsmm_descriptor_blob blob; + libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, alpha, beta, gemm_flags, libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO)); + if (NULL != desc) { + libxsmm_xmmfunction kernel; + libxsmm_gemm_internal_set_batchflag(desc, c, index_stride, batchsize, 0/*multi-threaded*/); + kernel = libxsmm_xmmdispatch(desc); + if (NULL != kernel.xmm) { + result = libxsmm_mmbatch_kernel(kernel, index_base, index_stride, + stride_a, stride_b, stride_c, a, b, c, batchsize, tid, ntasks, + libxsmm_typesize((libxsmm_datatype)iprec), otypesize, desc->flags); + } + } + } + if (EXIT_SUCCESS != result) { /* quiet fallback */ + if (EXIT_SUCCESS == libxsmm_mmbatch_blas(iprec, oprec, + transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + index_base, index_stride, stride_a, stride_b, stride_c, batchsize)) + { + if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) { + const size_t threshold = LIBXSMM_MNK_SIZE(m, n, m); + static size_t threshold_max = 0; + if (threshold_max < threshold) { + LIBXSMM_STDIO_ACQUIRE(); + fprintf(stderr, "LIBXSMM WARNING: "); + libxsmm_gemm_print2(stderr, iprec, oprec, transa, transb, &m, &n, &k, + alpha, NULL/*a*/, lda, NULL/*b*/, ldb, beta, NULL/*c*/, ldc); + fprintf(stderr, " => batched GEMM was falling back to BLAS!\n"); + LIBXSMM_STDIO_RELEASE(); + threshold_max = threshold; + } + } + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: libxsmm_mmbatch failed!\n"); + } + } + } +#if defined(LIBXSMM_BATCH_CHECK) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: incorrect arguments (libxsmm_mmbatch)!\n"); + } +#endif +} + + +LIBXSMM_API void libxsmm_gemm_batch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, + const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize) +{ + libxsmm_mmbatch(iprec, oprec, transa, transb, m, n, k, + alpha,a, lda, b, ldb, beta, c, ldc, index_base, index_stride, + stride_a, stride_b, stride_c, batchsize, 0/*tid*/, 1/*ntasks*/); +} + + +#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dgemm)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const double*, const double*, const libxsmm_blasint*, + const double*, const libxsmm_blasint*, + const double*, double*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ + libxsmm_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_sgemm)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const float*, const float*, const libxsmm_blasint*, + const float*, const libxsmm_blasint*, + const float*, float*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_sgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + libxsmm_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_wigemm)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const int*, const short*, const libxsmm_blasint*, + const short*, const libxsmm_blasint*, + const int*, int*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_wigemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const int* alpha, const short* a, const libxsmm_blasint* lda, + const short* b, const libxsmm_blasint* ldb, + const int* beta, int* c, const libxsmm_blasint* ldc) +{ + libxsmm_wigemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_bsgemm)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const float*, const libxsmm_bfloat16*, const libxsmm_blasint*, + const libxsmm_bfloat16*, const libxsmm_blasint*, + const float*, float*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_bsgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const libxsmm_bfloat16* a, const libxsmm_blasint* lda, + const libxsmm_bfloat16* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + libxsmm_bsgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_xgemm)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*, + const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const float*, const float*, const libxsmm_blasint*, + const float*, const libxsmm_blasint*, + const float*, float*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_xgemm)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + LIBXSMM_ASSERT(NULL != iprec && NULL != oprec); + libxsmm_blas_xgemm(*iprec, *oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_dgemm)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const double*, const double*, const libxsmm_blasint*, + const double*, const libxsmm_blasint*, + const double*, double*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_dgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const double* alpha, const double* a, const libxsmm_blasint* lda, + const double* b, const libxsmm_blasint* ldb, + const double* beta, double* c, const libxsmm_blasint* ldc) +{ + libxsmm_blas_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_sgemm)(const char*, const char*, + const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const float*, const float*, const libxsmm_blasint*, + const float*, const libxsmm_blasint*, + const float*, float*, const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_sgemm)(const char* transa, const char* transb, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const float* alpha, const float* a, const libxsmm_blasint* lda, + const float* b, const libxsmm_blasint* ldb, + const float* beta, float* c, const libxsmm_blasint* ldc) +{ + libxsmm_blas_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_mmbatch)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*, + const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*, + const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[], + const libxsmm_blasint*, const /*unsigned*/int*, const /*unsigned*/int*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_mmbatch)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + const libxsmm_blasint* batchsize, const /*unsigned*/int* tid, const /*unsigned*/int* ntasks) +{ + LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k); + LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize); + LIBXSMM_ASSERT(NULL != tid && NULL != ntasks); + libxsmm_mmbatch(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc, + *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize, *tid, *ntasks); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_gemm_batch)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*, + const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*, + const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*, + const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[], + const libxsmm_blasint*); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_gemm_batch)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec, + const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, + const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride, + const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + const libxsmm_blasint* batchsize) +{ + LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k); + LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize); + libxsmm_gemm_batch(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc, + *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize); +} + +#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_gemm.h b/third_party/libxsmm/src/libxsmm_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..8d9db076906a3ebc1474436b60906aa28a994ebe --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_gemm.h @@ -0,0 +1,219 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_GEMM_H +#define LIBXSMM_GEMM_H + +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_BLAS_WRAP_DYNAMIC) && defined(LIBXSMM_INTERCEPT_DYNAMIC) && (!defined(__BLAS) || (0 != __BLAS)) +# define LIBXSMM_BLAS_WRAP_DYNAMIC +#endif + +#if !defined(LIBXSMM_GEMM_LOCK) +# define LIBXSMM_GEMM_LOCK LIBXSMM_LOCK_DEFAULT +#endif +#if !defined(LIBXSMM_GEMM_MMBATCH_SCALE) +# define LIBXSMM_GEMM_MMBATCH_SCALE 1.5 +#endif +#if !defined(LIBXSMM_GEMM_MMBATCH_VERBOSITY) +# define LIBXSMM_GEMM_MMBATCH_VERBOSITY ((LIBXSMM_VERBOSITY_HIGH) + 1) +#endif +#if !defined(LIBXSMM_GEMM_NPARGROUPS) +# define LIBXSMM_GEMM_NPARGROUPS 128 +#endif + +#if !defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD) && \ + (defined(LIBXSMM_CONFIG_WRAP) && 0 != (LIBXSMM_CONFIG_WRAP)) && \ + (defined(LIBXSMM_BLAS_WRAP_DYNAMIC) || !defined(NDEBUG) || defined(_WIN32)) /* debug */ +# define LIBXSMM_WRAP LIBXSMM_CONFIG_WRAP +#endif + +/** Undefine (disarm) MKL's DIRECT_CALL macros. */ +#if (defined(MKL_DIRECT_CALL_SEQ) || defined(MKL_DIRECT_CALL)) +# if defined(sgemm_) +# undef sgemm_ +# endif +# if defined(dgemm_) +# undef dgemm_ +# endif +#endif + +#if !defined(LIBXSMM_BLAS_ERROR) +#define LIBXSMM_BLAS_ERROR(SYMBOL, PCOUNTER) do { \ + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(PCOUNTER, 1, LIBXSMM_ATOMIC_RELAXED)) { \ + fprintf(stderr, "LIBXSMM ERROR: application must be linked against LAPACK/BLAS %s!\n", SYMBOL); \ + } \ + } while(0) +#endif + +#if defined(LIBXSMM_BUILD) +# define LIBXSMM_BLAS_WRAPPER_STATIC1(TYPE, KIND, ORIGINAL) if (NULL == (ORIGINAL)) { \ + ORIGINAL = LIBXSMM_FSYMBOL(LIBXSMM_CONCATENATE(__real_, LIBXSMM_TPREFIX(TYPE, KIND))); \ + } +# define LIBXSMM_BLAS_WRAPPER_STATIC0 LIBXSMM_BLAS_WRAPPER_STATIC1 +#else +# define LIBXSMM_BLAS_WRAPPER_STATIC1(TYPE, KIND, ORIGINAL) if (NULL == (ORIGINAL)) { \ + ORIGINAL = (LIBXSMM_BLAS_FNTYPE(TYPE, KIND))LIBXSMM_BLAS_SYMBOL(TYPE, KIND); \ + } +# define LIBXSMM_BLAS_WRAPPER_STATIC0(TYPE, KIND, ORIGINAL) +#endif +#define LIBXSMM_BLAS_WRAPPER_STATIC(CONDITION, TYPE, KIND, ORIGINAL) \ + LIBXSMM_CONCATENATE(LIBXSMM_BLAS_WRAPPER_STATIC, CONDITION)(TYPE, KIND, ORIGINAL) + +#if defined(LIBXSMM_BLAS_WRAP_DYNAMIC) +# define LIBXSMM_BLAS_WRAPPER_DYNAMIC(TYPE, KIND, ORIGINAL, NEXT) { \ + union { const void* pfin; \ + LIBXSMM_BLAS_FNTYPE(TYPE, KIND) (*chain)(void); /* chain */ \ + LIBXSMM_BLAS_FNTYPE(TYPE, KIND) pfout; \ + } libxsmm_blas_wrapper_dynamic_ /*= { 0 }*/; \ + dlerror(); /* clear an eventual error status */ \ + libxsmm_blas_wrapper_dynamic_.chain = NEXT; \ + libxsmm_blas_wrapper_dynamic_.pfin = ((NULL == libxsmm_blas_wrapper_dynamic_.pfin) ? \ + dlsym(LIBXSMM_RTLD_NEXT, "libxsmm_original_" LIBXSMM_STRINGIFY(LIBXSMM_TPREFIX(TYPE, KIND))) : NULL); \ + if (NULL == libxsmm_blas_wrapper_dynamic_.pfout || NULL != dlerror() || NULL == libxsmm_blas_wrapper_dynamic_.chain()) { \ + libxsmm_blas_wrapper_dynamic_.pfin = dlsym(LIBXSMM_RTLD_NEXT, LIBXSMM_STRINGIFY(LIBXSMM_BLAS_SYMBOL(TYPE, KIND))); \ + /*LIBXSMM_ATOMIC_STORE(&(ORIGINAL), libxsmm_blas_wrapper_dynamic_.pfout, LIBXSMM_ATOMIC_RELAXED);*/ \ + ORIGINAL = (NULL == dlerror() ? libxsmm_blas_wrapper_dynamic_.pfout : NULL); \ + } \ + } +#else +# define LIBXSMM_BLAS_WRAPPER_DYNAMIC(TYPE, KIND, ORIGINAL, NEXT) +#endif + +#define LIBXSMM_BLAS_WRAPPER(CONDITION, TYPE, KIND, ORIGINAL, NEXT) if (NULL == (ORIGINAL)) { \ + LIBXSMM_BLAS_WRAPPER_DYNAMIC(TYPE, KIND, ORIGINAL, NEXT); \ + LIBXSMM_BLAS_WRAPPER_STATIC(CONDITION, TYPE, KIND, ORIGINAL); \ +} + + +/** Provides GEMM functions available via BLAS; NOT thread-safe. */ +LIBXSMM_API_INTERN void libxsmm_gemm_init(int archid); + +/** Finalizes the GEMM facility; NOT thread-safe. */ +LIBXSMM_API_INTERN void libxsmm_gemm_finalize(void); + +LIBXSMM_API_INTERN int libxsmm_gemm_prefetch2uid(libxsmm_gemm_prefetch_type prefetch); +LIBXSMM_API_INTERN libxsmm_gemm_prefetch_type libxsmm_gemm_uid2prefetch(int uid); + +#if defined(LIBXSMM_BUILD) +#if defined(LIBXSMM_BUILD_EXT) +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch)); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch)); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(__wrap_dgemm)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm)); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(__wrap_sgemm)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm)); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(__wrap_dgemv)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemv)); +LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(__wrap_sgemv)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemv)); +LIBXSMM_APIEXT void __wrap_dgemm_batch(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch)); +LIBXSMM_APIEXT void __wrap_sgemm_batch(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch)); +#endif +LIBXSMM_API void LIBXSMM_FSYMBOL(__real_dgemm_batch)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch)); +LIBXSMM_API void LIBXSMM_FSYMBOL(__real_sgemm_batch)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch)); +LIBXSMM_API void LIBXSMM_FSYMBOL(__real_dgemm)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm)); +LIBXSMM_API void LIBXSMM_FSYMBOL(__real_sgemm)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm)); +LIBXSMM_API void LIBXSMM_FSYMBOL(__real_dgemv)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemv)); +LIBXSMM_API void LIBXSMM_FSYMBOL(__real_sgemv)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemv)); +LIBXSMM_API void __real_dgemm_batch(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch)); +LIBXSMM_API void __real_sgemm_batch(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch)); +#endif + +LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, double, gemm_batch); +LIBXSMM_BLAS_SYMBOL_CDECL(LIBXSMM_BLAS_CONST*, *, double, gemm_batch); +LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, float, gemm_batch); +LIBXSMM_BLAS_SYMBOL_CDECL(LIBXSMM_BLAS_CONST*, *, float, gemm_batch); +LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, double, gemm); +LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, float, gemm); +LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, double, gemv); +LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, float, gemv); + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_gemm_handle { + libxsmm_xcopykernel copy_a, copy_b, copy_i, copy_o; + libxsmm_xmmfunction kernel[2]; + unsigned int m, n, k, lda, ldb, ldc; + /* kernel size (tile) */ + unsigned int km, kn, kk; + /* tile size per task */ + unsigned int dm, dn, dk; + unsigned int itypesize, otypesize; + /* number of tasks per direction */ + unsigned int mt, nt, kt; + int gemm_flags, flags; +}; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_mmbatch_item { + struct { + const void *a, *b; + void *c; + } value; + struct { + libxsmm_gemm_descriptor desc; + unsigned int count; + const char* symbol; + } stat; + /* TODO: consider padding */ +} libxsmm_mmbatch_item; + +LIBXSMM_API void libxsmm_gemm_internal_set_batchflag(libxsmm_gemm_descriptor* descriptor, void* c, libxsmm_blasint index_stride, + libxsmm_blasint batchsize, int multithreaded); + +LIBXSMM_API int libxsmm_mmbatch_kernel(libxsmm_xmmfunction kernel, libxsmm_blasint index_base, + libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + const void* a, const void* b, void* c, libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int ntasks, + unsigned char itypesize, unsigned char otypesize, int flags); + +LIBXSMM_API int libxsmm_mmbatch_blas( + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const void* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize); + +LIBXSMM_API_INTERN void libxsmm_dmmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const double* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const double* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize); + +LIBXSMM_API_INTERN void libxsmm_smmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const float* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const float* beta, void* c, const libxsmm_blasint* ldc, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[], + libxsmm_blasint batchsize); + +LIBXSMM_EXTERN_C typedef void (*libxsmm_mmbatch_flush_function)(void); + +/** auto-batch descriptor (filter). */ +LIBXSMM_APIVAR_PUBLIC(libxsmm_gemm_descriptor libxsmm_mmbatch_desc); +/** Records a batch of SMMs or is used for batch-reduce. */ +LIBXSMM_APIVAR_PUBLIC(void* libxsmm_mmbatch_array); +/** Lock: libxsmm_mmbatch_begin, libxsmm_mmbatch_end, internal_mmbatch_flush. */ +LIBXSMM_APIVAR_PUBLIC(LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) libxsmm_mmbatch_lock); +/** Maximum size of the recorded batch. */ +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_mmbatch_size); +/** Maximum number of parallelized batch-groups. */ +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_gemm_npargroups); +/** Minimum batchsize per thread/task. */ +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_gemm_taskgrain); +/** Determines if OpenMP tasks are used. */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_gemm_tasks); +/** + * Intercepted GEMM + * - [>=1 and odd]: sequential and non-tiled (small problem sizes only) + * - [>=2 and even]: parallelized and tiled (all problem sizes) + * - [>=3 and odd]: GEMV is intercepted; small problem sizes + * - [>=4 and even]: GEMV is intercepted; all problem sizes + * - [0]: disabled + */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_gemm_wrap); + +/** Determines the default prefetch strategy, which is used in case of LIBXSMM_PREFETCH_AUTO. */ +LIBXSMM_APIVAR_PRIVATE(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch_default); +/** Determines the prefetch strategy, which is used in case of LIBXSMM_PREFETCH_AUTO. */ +LIBXSMM_APIVAR_PRIVATE(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch); + +#endif /*LIBXSMM_GEMM_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_generator.c b/third_party/libxsmm/src/libxsmm_generator.c new file mode 100644 index 0000000000000000000000000000000000000000..4d76d8ee246a1d8e92015de79c5575b43eb34069 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_generator.c @@ -0,0 +1,530 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_PRODUCT_LIMIT) +# define LIBXSMM_PRODUCT_LIMIT 1024 +#endif + + +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_intrinsics_mm512_rng_state0[16]); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_intrinsics_mm512_rng_state1[16]); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_intrinsics_mm512_rng_state2[16]); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_intrinsics_mm512_rng_state3[16]); + +/* definition of corresponding variables */ +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_ninit); +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_target_archid); +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_verbosity); +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_se); + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_dgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + double alpha, double beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR(*result.ptr, LIBXSMM_GEMM_PRECISION(double), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* quiet error (unsupported) */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_sgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + float alpha, float beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR(*result.ptr, LIBXSMM_GEMM_PRECISION(float), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* unsupported */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_wigemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + int alpha, int beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR2(*result.ptr, LIBXSMM_GEMM_PRECISION(short), LIBXSMM_GEMM_PRECISION(int), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* unsupported */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bsgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + float alpha, float beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR2(*result.ptr, LIBXSMM_GEMM_PRECISION(libxsmm_bfloat16), LIBXSMM_GEMM_PRECISION(float), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* unsupported */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + float alpha, float beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR2(*result.ptr, LIBXSMM_GEMM_PRECISION(libxsmm_bfloat16), LIBXSMM_GEMM_PRECISION(libxsmm_bfloat16), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* unsupported */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bigemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + int alpha, int beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR2(*result.ptr, LIBXSMM_GEMM_PRECISION(char), LIBXSMM_GEMM_PRECISION(int), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* unsupported */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_bbgemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, + int alpha, int beta, int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + LIBXSMM_GEMM_DESCRIPTOR2(*result.ptr, LIBXSMM_GEMM_PRECISION(char), LIBXSMM_GEMM_PRECISION(char), + flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* unsupported */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_dinit(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision precision, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, double alpha, double beta, + int flags, int prefetch) +{ + return libxsmm_gemm_descriptor_dinit2(blob, precision, precision, m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch); +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_dinit2(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, double alpha, double beta, + int flags, int prefetch) +{ + union { + libxsmm_gemm_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + if (LIBXSMM_GEMM_NO_BYPASS(flags, alpha, beta) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc) + && LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k)) + { + result.blob = blob; + /* Note: iprec/oprec combination is not checked to omit type-switch (invalid combination may result in BE-error) */ + LIBXSMM_GEMM_DESCRIPTOR2(*result.ptr, iprec, oprec, flags, m, n, k, lda, ldb, ldc, alpha, beta, prefetch); + } + else { /* quiet error (unsupported) */ + result.ptr = NULL; + } + return result.ptr; +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision precision, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, const void* alpha, const void* beta, + int flags, int prefetch) +{ + return libxsmm_gemm_descriptor_init2(blob, precision, precision, m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch); +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_init2(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, const void* alpha, const void* beta, + int flags, int prefetch) +{ + return libxsmm_gemm_descriptor_init3(blob, iprec, oprec, m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch, + NULL/*dalpha*/, NULL/*dbeta*/); +} + + +LIBXSMM_API libxsmm_gemm_descriptor* libxsmm_gemm_descriptor_init3(libxsmm_descriptor_blob* blob, + libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_blasint lda, libxsmm_blasint ldb, libxsmm_blasint ldc, const void* alpha, const void* beta, + int flags, int prefetch, double* dalpha, double* dbeta) +{ + /* avoid warning about potentially uninitialized variable (initialize outside of control flow) */ + libxsmm_gemm_descriptor* result = NULL; + switch (iprec) { + case LIBXSMM_GEMM_PRECISION_F64: { + const double aa = (NULL != alpha ? *((const double*)alpha) : (LIBXSMM_ALPHA)); + const double bb = (NULL != beta ? *((const double*)beta) : (LIBXSMM_BETA)); + LIBXSMM_ASSERT(LIBXSMM_GEMM_PRECISION_F64 == oprec); + result = libxsmm_dgemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = aa; + if (NULL != dbeta) *dbeta = bb; + } break; + case LIBXSMM_GEMM_PRECISION_F32: { + const float aa = (NULL != alpha ? *((const float*)alpha) : (LIBXSMM_ALPHA)); + const float bb = (NULL != beta ? *((const float*)beta) : (LIBXSMM_BETA)); + LIBXSMM_ASSERT(LIBXSMM_GEMM_PRECISION_F32 == oprec); + result = libxsmm_sgemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = (double)aa; + if (NULL != dbeta) *dbeta = (double)bb; + } break; + case LIBXSMM_GEMM_PRECISION_I16: { + /** + * Take alpha and beta as short data although wgemm works on integers. + * However, alpha and beta are only JIT-supported for certain values, + * and the call-side may not distinct different input and output types + * (integer/short), hence it is safer to only read short data. + */ + const short aa = (short)(NULL != alpha ? *((const short*)alpha) : (LIBXSMM_ALPHA)); + const short bb = (short)(NULL != beta ? *((const short*)beta) : (LIBXSMM_BETA)); + LIBXSMM_ASSERT(LIBXSMM_GEMM_PRECISION_I32 == oprec); + result = libxsmm_wigemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = (double)aa; + if (NULL != dbeta) *dbeta = (double)bb; + } break; + case LIBXSMM_GEMM_PRECISION_I8: { + /** + * Take alpha and beta as short data although wgemm works on integers. + * However, alpha and beta are only JIT-supported for certain values, + * and the call-side may not distinct different input and output types + * (integer/short), hence it is safer to only read short data. + */ + if (LIBXSMM_GEMM_PRECISION_I32 == oprec) { + const short aa = (short)(NULL != alpha ? *((const short*)alpha) : (LIBXSMM_ALPHA)); + const short bb = (short)(NULL != beta ? *((const short*)beta) : (LIBXSMM_BETA)); + result = libxsmm_bigemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = (double)aa; + if (NULL != dbeta) *dbeta = (double)bb; + } + else if (LIBXSMM_GEMM_PRECISION_I8 == oprec) { + const short aa = (short)(NULL != alpha ? *((const short*)alpha) : (LIBXSMM_ALPHA)); + const short bb = (short)(NULL != beta ? *((const short*)beta) : (LIBXSMM_BETA)); + result = libxsmm_bbgemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = (double)aa; + if (NULL != dbeta) *dbeta = (double)bb; + } + } break; + case LIBXSMM_GEMM_PRECISION_BF16: { + if (LIBXSMM_GEMM_PRECISION_F32 == oprec) { + const float aa = (NULL != alpha ? *((const float*)alpha) : (LIBXSMM_ALPHA)); + const float bb = (NULL != beta ? *((const float*)beta) : (LIBXSMM_BETA)); + result = libxsmm_bsgemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = (double)aa; + if (NULL != dbeta) *dbeta = (double)bb; + } + else if (LIBXSMM_GEMM_PRECISION_BF16 == oprec) { + const float aa = (NULL != alpha ? *((const float*)alpha) : (LIBXSMM_ALPHA)); + const float bb = (NULL != beta ? *((const float*)beta) : (LIBXSMM_BETA)); + result = libxsmm_bgemm_descriptor_init(blob, m, n, k, lda, ldb, ldc, aa, bb, flags, prefetch); + if (NULL != dalpha) *dalpha = (double)aa; + if (NULL != dbeta) *dbeta = (double)bb; + } + } break; + default: /* result remains NULL */; + } + if (NULL == result) { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: GEMM precision is not supported!\n"); + } + } + return result; +} + + +LIBXSMM_API libxsmm_meltw_descriptor* libxsmm_meltw_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_datatype in_type, libxsmm_datatype out_type, + libxsmm_blasint m, libxsmm_blasint n, + libxsmm_blasint ldi, libxsmm_blasint ldo, + unsigned short flags, unsigned char param, unsigned char operation) +{ + union { + libxsmm_meltw_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + LIBXSMM_DESCRIPTOR_CLEAR(blob); + result.blob = blob; + result.ptr->datatype = (unsigned char)LIBXSMM_GETENUM(in_type, out_type); + result.ptr->datatype2 = 0; + result.ptr->flags = (unsigned short)flags; + result.ptr->operation = (unsigned char)operation; + result.ptr->param = (unsigned char)param; + result.ptr->ldi = ldi; + result.ptr->ldo = ldo; + result.ptr->ldi2 = 0; + result.ptr->ldi3 = 0; + result.ptr->m = m; + result.ptr->n = n; + return result.ptr; +} + + +LIBXSMM_API libxsmm_meltw_descriptor* libxsmm_meltw_descriptor_init2(libxsmm_descriptor_blob* blob, + libxsmm_datatype in_type, libxsmm_datatype in2_type, libxsmm_datatype out_type, libxsmm_datatype out2_type, + libxsmm_blasint m, libxsmm_blasint n, + libxsmm_blasint ldi, libxsmm_blasint ldo, libxsmm_blasint ldi2, libxsmm_blasint ldi3, + unsigned short flags, unsigned char param, unsigned char operation) +{ + union { + libxsmm_meltw_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + LIBXSMM_DESCRIPTOR_CLEAR(blob); + result.blob = blob; + result.ptr->datatype = (unsigned char)LIBXSMM_GETENUM(in_type, out_type); + result.ptr->datatype2 = (unsigned char)LIBXSMM_GETENUM(in2_type, out2_type); + result.ptr->flags = (unsigned short)flags; + result.ptr->operation = (unsigned char)operation; + result.ptr->param = (unsigned char)param; + result.ptr->ldi = ldi; + result.ptr->ldo = ldo; + result.ptr->ldi2 = ldi2; + result.ptr->ldi3 = ldi3; + result.ptr->m = m; + result.ptr->n = n; + return result.ptr; +} + + +LIBXSMM_API libxsmm_meqn_descriptor* libxsmm_meqn_descriptor_init(libxsmm_descriptor_blob* blob, + libxsmm_datatype out_type, libxsmm_blasint m, libxsmm_blasint n, + libxsmm_blasint ldo, unsigned int eqn_idx) +{ + union { + libxsmm_meqn_descriptor* ptr; + libxsmm_descriptor_blob* blob; + } result; + LIBXSMM_DESCRIPTOR_CLEAR(blob); + result.blob = blob; + result.ptr->datatype = (unsigned char)LIBXSMM_GETENUM( LIBXSMM_DATATYPE_UNSUPPORTED, out_type); + result.ptr->eqn_idx = eqn_idx; + result.ptr->ldo = ldo; + result.ptr->m = m; + result.ptr->n = n; + return result.ptr; +} + + +LIBXSMM_API size_t libxsmm_gcd(size_t a, size_t b) +{ + while (0 != b) { + const size_t r = a % b; + a = b; b = r; + } + return 0 != a ? a : 1; +} + + +LIBXSMM_API size_t libxsmm_lcm(size_t a, size_t b) +{ + const size_t gcd = libxsmm_gcd(a, b); + return 0 != gcd ? ((a / gcd) * b) : 0; +} + + +LIBXSMM_API int libxsmm_primes_u32(unsigned int num, unsigned int num_factors_n32[]) +{ + unsigned int c = num, i; + int n = 0; + if (0 < c && 0 == (c & 1)) { /* non-zero even */ + unsigned int j = c / 2; + while (c == (2 * j)) { + num_factors_n32[n++] = 2; + c = j; j /= 2; + } + } + for (i = 3; i <= c; i += 2) { + unsigned int j = c / i; + while (c == (i * j)) { + num_factors_n32[n++] = i; + c = j; j /= i; + } + if ((i * i) > num) { + break; + } + } + if (1 < c && 0 != n) { + num_factors_n32[n++] = c; + } + return n; +} + + +LIBXSMM_API_INLINE unsigned int internal_product_limit(unsigned int product, unsigned int limit) +{ + unsigned int fact[32], maxp = limit, result = 1; + int i, n; + /* attempt to lower the memory requirement for DP; can miss best solution */ + if (LIBXSMM_PRODUCT_LIMIT < limit) { + const unsigned int minfct = (limit + limit - 1) / LIBXSMM_PRODUCT_LIMIT; + const unsigned int maxfct = (unsigned int)libxsmm_gcd(product, limit); + result = maxfct; + if (minfct < maxfct) { + n = libxsmm_primes_u32(result, fact); + for (i = 0; i < n; ++i) { + if (minfct < fact[i]) { + result = fact[i]; + break; + } + } + } + maxp /= result; + } + if (LIBXSMM_PRODUCT_LIMIT >= maxp) { + unsigned int k[2][LIBXSMM_PRODUCT_LIMIT], *k0 = k[0], *k1 = k[1], *kt, p; + n = libxsmm_primes_u32(product / result, fact); + /* initialize table with trivial factor */ + for (p = 0; p <= maxp; ++p) k[0][p] = 1; + k[0][0] = k[1][0] = 1; + for (i = 1; i <= n; ++i) { + for (p = 1; p <= maxp; ++p) { + const unsigned int f = fact[i - 1], h = k0[p]; + if (p < f) { + k1[p] = h; + } + else { + const unsigned int g = f * k0[p / f]; + k1[p] = LIBXSMM_MAX(g, h); + } + } + kt = k0; k0 = k1; k1 = kt; + } + result *= k0[maxp]; + } + else { /* trivial approximation */ + n = libxsmm_primes_u32(product, fact); + for (i = 0; i < n; ++i) { + const unsigned int f = result * fact[i]; + if (f <= limit) { + result = f; + } + else break; + } + } + return result; +} + + +LIBXSMM_API unsigned int libxsmm_product_limit(unsigned int product, unsigned int limit, int is_lower) +{ + unsigned int result; + if (1 < limit) { /* check for fast-path */ + result = internal_product_limit(product, limit); + } + else { + result = limit; + } + if (0 != is_lower && limit < product) { + if (result < limit) { + result = internal_product_limit(product, 2 * limit - 1); + } + if (result < limit) { + result = product; + } + LIBXSMM_ASSERT(limit <= result); + } + if (product < result) { + result = product; + } + LIBXSMM_ASSERT(result <= product); + return result; +} + diff --git a/third_party/libxsmm/src/libxsmm_generator_gemm_driver.c b/third_party/libxsmm/src/libxsmm_generator_gemm_driver.c new file mode 100644 index 0000000000000000000000000000000000000000..d83c6df8a852c4533656eb1b076f09db07e21ce6 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_generator_gemm_driver.c @@ -0,0 +1,280 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include + + +LIBXSMM_INLINE void print_help(void) { + printf("\nwrong usage -> exit!\n\n\n"); + printf("Usage (sparse*dense=dense, dense*sparse=dense):\n"); + printf(" sparse, sparse_csr, sparse_csr_reg\n"); + printf(" filename to append\n"); + printf(" routine name\n"); + printf(" M\n"); + printf(" N\n"); + printf(" K\n"); + printf(" LDA (if < 1 --> A sparse)\n"); + printf(" LDB (if < 1 --> B sparse)\n"); + printf(" LDC\n"); + printf(" alpha: 1\n"); + printf(" beta: 0 or 1\n"); + printf(" 0: unaligned A, otherwise aligned (ignored for sparse)\n"); + printf(" 0: unaligned C, otherwise aligned (ignored for sparse)\n"); + printf(" ARCH: noarch, wsm, snb, hsw, knl, knm, skx, clx, cpx\n"); + printf(" PREFETCH: nopf (none), pfsigonly, other options fallback to pfsigonly\n"); + printf(" PRECISION: SP, DP\n"); + printf(" matrix input (CSC mtx file)\n"); + printf("\n\n"); + printf("Usage (dense*dense=dense):\n"); + printf(" dense, dense_asm\n"); + printf(" filename to append\n"); + printf(" routine name\n"); + printf(" M\n"); + printf(" N\n"); + printf(" K\n"); + printf(" LDA\n"); + printf(" LDB\n"); + printf(" LDC\n"); + printf(" alpha: -1 or 1\n"); + printf(" beta: 0 or 1\n"); + printf(" 0: unaligned A, otherwise aligned\n"); + printf(" 0: unaligned C, otherwise aligned\n"); + printf(" ARCH: noarch, wsm, snb, hsw, knl, knm, skx, clx, cpx\n"); + printf(" PREFETCH: nopf (none), pfsigonly, BL2viaC, AL2, curAL2,\n" + " AL2_BL2viaC, curAL2_BL2viaC,\n"); + printf(" PRECISION: I16, SP, DP\n"); + printf("\n\n\n\n"); +} + +int main(int argc, char* argv []) { + const libxsmm_gemm_descriptor* l_xgemm_desc = 0; + int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + libxsmm_gemm_prefetch_type l_prefetch; + libxsmm_descriptor_blob l_xgemm_blob; + char* l_type; + char* l_file_out; + char* l_matrix_file_in; + char* l_routine_name; + char* l_arch; + char* l_precision; + int l_m = 0; + int l_n = 0; + int l_k = 0; + int l_lda = 0; + int l_ldb = 0; + int l_ldc = 0; + int l_aligned_a = 0; + int l_aligned_c = 0; + double l_alpha = 0; + double l_beta = 0; + int l_single_precision = 0; + int l_is_csr = 0; + + /* check argument count for a valid range */ + if (argc != 17 && argc != 18) { + print_help(); + return EXIT_FAILURE; + } + + /* names of files and routines */ + l_type = argv[1]; + l_file_out = argv[2]; + l_routine_name = argv[3]; + + /* xgemm sizes */ + l_m = atoi(argv[4]); + l_n = atoi(argv[5]); + l_k = atoi(argv[6]); + l_lda = atoi(argv[7]); + l_ldb = atoi(argv[8]); + l_ldc = atoi(argv[9]); + + /* condense < 1 to 0 for lda and ldb */ + if ( l_lda < 1 ) + l_lda = 0; + if ( l_ldb < 1 ) + l_ldb = 0; + + /* some sugar */ + l_alpha = atof(argv[10]); + l_beta = atof(argv[11]); + l_aligned_a = atoi(argv[12]); + l_aligned_c = atoi(argv[13]); + + l_flags |= (0 != l_aligned_a ? LIBXSMM_GEMM_FLAG_ALIGN_A : 0); + l_flags |= (0 != l_aligned_c ? LIBXSMM_GEMM_FLAG_ALIGN_C : 0); + + /* arch specific stuff */ + l_arch = argv[14]; + l_precision = argv[16]; + + /* some initial parameters checks */ + /* check for sparse / dense only */ + if ( (strcmp(l_type, "sparse") != 0) && + (strcmp(l_type, "sparse_csr") != 0) && + (strcmp(l_type, "sparse_csr_reg") != 0) && + (strcmp(l_type, "dense") != 0) && + (strcmp(l_type, "dense_asm") != 0) ) { + print_help(); + return EXIT_FAILURE; + } + + /* check for the right number of arguments depending on type */ + if ( ( (strcmp(l_type, "sparse") == 0) && (argc != 18) ) || + ( (strcmp(l_type, "sparse_csr") == 0) && (argc != 18) ) || + ( (strcmp(l_type, "sparse_csr_reg") == 0) && (argc != 18) ) || + ( (strcmp(l_type, "dense") == 0) && (argc != 17) ) || + ( (strcmp(l_type, "dense_asm") == 0) && (argc != 17) ) ) { + print_help(); + return EXIT_FAILURE; + } + + /* set value of prefetch flag */ + if (strcmp("nopf", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_NONE; + } + else if (strcmp("pfsigonly", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_SIGONLY; + } + else if (strcmp("BL2viaC", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_BL2_VIA_C; + } + else if (strcmp("curAL2", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_AL2_AHEAD; + } + else if (strcmp("curAL2_BL2viaC", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD; + } + else if (strcmp("AL2", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_AL2; + } + else if (strcmp("AL2_BL2viaC", argv[15]) == 0) { + l_prefetch = LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C; + } + else { + print_help(); + return EXIT_FAILURE; + } + + /* check value of arch flag */ + if ( (strcmp(l_arch, "wsm") != 0) && + (strcmp(l_arch, "snb") != 0) && + (strcmp(l_arch, "hsw") != 0) && + (strcmp(l_arch, "knl") != 0) && + (strcmp(l_arch, "knm") != 0) && + (strcmp(l_arch, "skx") != 0) && + (strcmp(l_arch, "clx") != 0) && + (strcmp(l_arch, "cpx") != 0) && + (strcmp(l_arch, "noarch") != 0) ) { + print_help(); + return EXIT_FAILURE; + } + + /* check and evaluate precision flag */ + if ( strcmp(l_precision, "SP") == 0 ) { + l_single_precision = 1; + } else if ( strcmp(l_precision, "DP") == 0 ) { + l_single_precision = 0; + } else if ( strcmp(l_precision, "I16") == 0 ) { + l_single_precision = 2; + } else { + print_help(); + return EXIT_FAILURE; + } + + /* check alpha */ + if ((l_alpha < -1 || -1 < l_alpha) && (l_alpha < 1 || 1 < l_alpha)) { + print_help(); + return EXIT_FAILURE; + } + + /* check beta */ + if ((l_beta < 0 || 0 < l_beta) && (l_beta < 1 || 1 < l_beta)) { + print_help(); + return EXIT_FAILURE; + } + + switch (l_single_precision) { + case 0: { + l_xgemm_desc = libxsmm_gemm_descriptor_dinit(&l_xgemm_blob, LIBXSMM_GEMM_PRECISION_F64, + l_m, l_n, l_k, l_lda, l_ldb, l_ldc, l_alpha, l_beta, l_flags, l_prefetch); + } break; + case 1: { + l_xgemm_desc = libxsmm_gemm_descriptor_dinit(&l_xgemm_blob, LIBXSMM_GEMM_PRECISION_F32, + l_m, l_n, l_k, l_lda, l_ldb, l_ldc, l_alpha, l_beta, l_flags, l_prefetch); + } break; + case 2: { + l_xgemm_desc = libxsmm_gemm_descriptor_dinit(&l_xgemm_blob, LIBXSMM_GEMM_PRECISION_I16, + l_m, l_n, l_k, l_lda, l_ldb, l_ldc, l_alpha, l_beta, l_flags, l_prefetch); + } break; + default: { + print_help(); + return EXIT_FAILURE; + } + } + + if (NULL == l_xgemm_desc) { + print_help(); + return EXIT_FAILURE; + } + + if ( strcmp(l_type, "sparse") == 0 || strcmp(l_type, "sparse_csr") == 0 || + strcmp(l_type, "sparse_csr_reg") == 0 ) { + /* read additional parameter for CSC/CSR description */ + l_matrix_file_in = argv[17]; + + /* some more restrictive checks are needed in case of sparse */ + if ( (l_alpha < 1) || (1 < l_alpha) ) { + print_help(); + return EXIT_FAILURE; + } + + if (l_lda < 1 && l_ldb < 1) { + print_help(); + return EXIT_FAILURE; + } + + if (l_ldc < 1) { + print_help(); + return EXIT_FAILURE; + } + + if ( l_single_precision > 1 ) { + print_help(); + return EXIT_FAILURE; + } + + if ( strcmp(l_type, "sparse_csr") == 0 ) { + l_is_csr = 1; + } + if ( strcmp(l_type, "sparse_csr_reg") == 0 ) { + l_is_csr = 3; + } + + libxsmm_generator_spgemm( l_file_out, l_routine_name, l_xgemm_desc, l_arch, l_matrix_file_in, l_is_csr ); + } + + if ( (strcmp(l_type, "dense") == 0) || + (strcmp(l_type, "dense_asm") == 0) ) { + if (l_lda < 1 || l_ldb < 1 || l_ldc < 1) { + print_help(); + return EXIT_FAILURE; + } + + if ( strcmp(l_type, "dense") == 0 ) { + libxsmm_generator_gemm_inlineasm( l_file_out, l_routine_name, l_xgemm_desc, l_arch ); + } else { + libxsmm_generator_gemm_directasm( l_file_out, l_routine_name, l_xgemm_desc, l_arch ); + } + } + + return EXIT_SUCCESS; +} + diff --git a/third_party/libxsmm/src/libxsmm_hash.c b/third_party/libxsmm/src/libxsmm_hash.c new file mode 100644 index 0000000000000000000000000000000000000000..8f3289c709f441cf07cbe8b3a8aef519ab03d3a6 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_hash.c @@ -0,0 +1,595 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_hash.h" +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_HASH_ALIGNMENT) +# define LIBXSMM_HASH_ALIGNMENT 8 +#endif + +#define LIBXSMM_HASH_U64(FN, SEED, BEGIN, END) { \ + const uint8_t *const end = (NULL != (END) ? ((END) - 7) : NULL); \ + for (; (BEGIN) < end; (BEGIN) += 8) { LIBXSMM_ASSERT(NULL != (BEGIN) || NULL == (END)); \ + SEED = (uint32_t)FN(SEED, BEGIN); \ + } \ +} +#define LIBXSMM_HASH_U32(FN, SEED, BEGIN, END) { \ + const uint8_t *const next = (BEGIN) + 4; \ + if (next <= (END)) { LIBXSMM_ASSERT(NULL != (BEGIN) || NULL == (END)); \ + SEED = FN(SEED, BEGIN); BEGIN = next; \ + } \ +} +#define LIBXSMM_HASH_U16(FN, SEED, BEGIN, END) { \ + const uint8_t *const next = (BEGIN) + 2; \ + if (next <= (END)) { LIBXSMM_ASSERT(NULL != (BEGIN) || NULL == (END)); \ + SEED = FN(SEED, BEGIN); BEGIN = next; \ + } \ +} +#define LIBXSMM_HASH_U8(FN, SEED, BEGIN, END) { \ + if ((BEGIN) < (END)) { LIBXSMM_ASSERT(NULL != (BEGIN) || NULL == (END)); \ + SEED = FN(SEED, BEGIN); ++(BEGIN); \ + } \ +} + +#define LIBXSMM_HASH_CRC32_U8(SEED, PVALUE) _mm_crc32_u8(SEED, *(const uint8_t*)(PVALUE)) +#define LIBXSMM_HASH_CRC32_U16(SEED, PVALUE) _mm_crc32_u16(SEED, *(const uint16_t*)(PVALUE)) +#define LIBXSMM_HASH_CRC32_U32(SEED, PVALUE) _mm_crc32_u32(SEED, *(const uint32_t*)(PVALUE)) + +#if (64 > (LIBXSMM_BITS)) || defined(__PGI) +# define LIBXSMM_HASH_CRC32_U64(SEED, PVALUE) \ + LIBXSMM_HASH_CRC32_U32(LIBXSMM_HASH_CRC32_U32((uint32_t)(SEED), PVALUE), (const uint32_t*)(PVALUE) + 1) +#else +# define LIBXSMM_HASH_CRC32_U64(SEED, PVALUE) _mm_crc32_u64(SEED, *(const uint64_t*)(PVALUE)) +#endif + +#define LIBXSMM_HASH_UNALIGNED(FN64, FN32, FN16, FN8, SEED, DATA, SIZE) { \ + const uint8_t *begin = (const uint8_t*)(DATA); \ + const uint8_t *const endb = begin + (SIZE); \ + LIBXSMM_HASH_U64(FN64, SEED, begin, endb); \ + LIBXSMM_HASH_U32(FN32, SEED, begin, endb); \ + LIBXSMM_HASH_U16(FN16, SEED, begin, endb); \ + return begin == endb ? (SEED) : FN8(SEED, begin); \ +} + +#if defined(LIBXSMM_HASH_ALIGNMENT) && 8 < (LIBXSMM_HASH_ALIGNMENT) +# define LIBXSMM_HASH(FN64, FN32, FN16, FN8, SEED, DATA, SIZE) { \ + const uint8_t *begin = (const uint8_t*)(DATA); \ + const uint8_t *const endb = begin + (SIZE); \ + const uint8_t *const enda = LIBXSMM_ALIGN(begin, LIBXSMM_HASH_ALIGNMENT); \ + if ((SIZE) > (size_t)(endb - enda)) { \ + LIBXSMM_HASH_U64(FN64, SEED, begin, enda); \ + LIBXSMM_HASH_U32(FN32, SEED, begin, enda); \ + LIBXSMM_HASH_U16(FN16, SEED, begin, enda); \ + LIBXSMM_HASH_U8(FN8, SEED, begin, enda); \ + } \ + LIBXSMM_ASSUME_ALIGNED(begin, LIBXSMM_HASH_ALIGNMENT); \ + LIBXSMM_HASH_U64(FN64, SEED, begin, endb); \ + LIBXSMM_HASH_U32(FN32, SEED, begin, endb); \ + LIBXSMM_HASH_U16(FN16, SEED, begin, endb); \ + return begin == endb ? (SEED) : FN8(SEED, begin); \ + } +#elif defined(LIBXSMM_HASH_ALIGNMENT) && 1 < (LIBXSMM_HASH_ALIGNMENT) +# define LIBXSMM_HASH(FN64, FN32, FN16, FN8, SEED, DATA, SIZE) { \ + const uint8_t *begin = (const uint8_t*)(DATA); \ + const uint8_t *const endb = begin + (SIZE); \ + const uint8_t *const enda = LIBXSMM_ALIGN(begin, LIBXSMM_HASH_ALIGNMENT); \ + if ((SIZE) > (size_t)(endb - enda)) { \ + LIBXSMM_HASH_U32(FN32, SEED, begin, enda); \ + LIBXSMM_HASH_U16(FN16, SEED, begin, enda); \ + LIBXSMM_HASH_U8(FN8, SEED, begin, enda); \ + } \ + LIBXSMM_ASSUME_ALIGNED(begin, LIBXSMM_HASH_ALIGNMENT); \ + LIBXSMM_HASH_U64(FN64, SEED, begin, endb); \ + LIBXSMM_HASH_U32(FN32, SEED, begin, endb); \ + LIBXSMM_HASH_U16(FN16, SEED, begin, endb); \ + return begin == endb ? (SEED) : FN8(SEED, begin); \ + } +#else +# define LIBXSMM_HASH LIBXSMM_HASH_UNALIGNED +#endif + +typedef uint32_t internal_crc32_entry_type[256]; +LIBXSMM_APIVAR_DEFINE(const internal_crc32_entry_type* internal_crc32_table); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_u32_function); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_u64_function); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_u128_function); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_u256_function); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_u384_function); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_u512_function); +LIBXSMM_APIVAR_DEFINE(libxsmm_hash_function internal_hash_function); + + +LIBXSMM_API_INLINE unsigned int internal_crc32_u8(unsigned int seed, const void* value) +{ + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8 && NULL != internal_crc32_table); + return internal_crc32_table[0][(seed^(*pu8)) & 0xFF] ^ (seed >> 8); +} + + +LIBXSMM_API_INLINE unsigned int internal_crc32_u16(unsigned int seed, const void* value) +{ + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u8(seed, pu8 + 0); + seed = internal_crc32_u8(seed, pu8 + 1); + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u32(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int internal_crc32_u32(unsigned int seed, const void* value, ...) +{ + const uint32_t *const pu32 = (const uint32_t*)value; + uint32_t c0, c1, c2, c3, s; + LIBXSMM_ASSERT(NULL != pu32 && NULL != internal_crc32_table); + s = seed ^ (*pu32); + c0 = internal_crc32_table[0][(s >> 24) & 0xFF]; + c1 = internal_crc32_table[1][(s >> 16) & 0xFF]; + c2 = internal_crc32_table[2][(s >> 8) & 0xFF]; + c3 = internal_crc32_table[3][(s & 0xFF)]; + return (c0 ^ c1) ^ (c2 ^ c3); +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u32_sse4(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_u32_sse4(unsigned int seed, const void* value, ...) +{ +#if defined(LIBXSMM_INTRINSICS_SSE42) + return LIBXSMM_HASH_CRC32_U32(seed, value); +#else + return internal_crc32_u32(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u64(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int internal_crc32_u64(unsigned int seed, const void* value, ...) +{ + const uint32_t *const pu32 = (const uint32_t*)value; + LIBXSMM_ASSERT(NULL != pu32); + seed = internal_crc32_u32(seed, pu32 + 0); + seed = internal_crc32_u32(seed, pu32 + 1); + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u64_sse4(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_u64_sse4(unsigned int seed, const void* value, ...) +{ +#if defined(LIBXSMM_INTRINSICS_SSE42) + return (unsigned int)LIBXSMM_HASH_CRC32_U64(seed, value); +#else + return internal_crc32_u64(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u128(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int internal_crc32_u128(unsigned int seed, const void* value, ...) +{ + const uint64_t *const pu64 = (const uint64_t*)value; + LIBXSMM_ASSERT(NULL != pu64); + seed = internal_crc32_u64(seed, pu64 + 0); + seed = internal_crc32_u64(seed, pu64 + 1); + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u128_sse4(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_u128_sse4(unsigned int seed, const void* value, ...) +{ +#if defined(LIBXSMM_INTRINSICS_SSE42) + const uint64_t *const pu64 = (const uint64_t*)value; + LIBXSMM_ASSERT(NULL != pu64); + seed = (unsigned int)LIBXSMM_HASH_CRC32_U64(seed, pu64 + 0); + seed = (unsigned int)LIBXSMM_HASH_CRC32_U64(seed, pu64 + 1); +#else + seed = internal_crc32_u128(seed, value); +#endif + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u256(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int internal_crc32_u256(unsigned int seed, const void* value, ...) +{ + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u128(seed, pu8 + 0x00); + seed = internal_crc32_u128(seed, pu8 + 0x10); + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u256_sse4(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_u256_sse4(unsigned int seed, const void* value, ...) +{ +#if defined(LIBXSMM_INTRINSICS_SSE42) + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u128_sse4(seed, pu8 + 0x00); + seed = internal_crc32_u128_sse4(seed, pu8 + 0x10); + return seed; +#else + return internal_crc32_u256(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u384(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int internal_crc32_u384(unsigned int seed, const void* value, ...) +{ + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u256(seed, pu8 + 0x00); + seed = internal_crc32_u128(seed, pu8 + 0x20); + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u384_sse4(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_u384_sse4(unsigned int seed, const void* value, ...) +{ +#if defined(LIBXSMM_INTRINSICS_SSE42) + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u256_sse4(seed, pu8 + 0x00); + seed = internal_crc32_u128_sse4(seed, pu8 + 0x20); + return seed; +#else + return internal_crc32_u384(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u512(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int internal_crc32_u512(unsigned int seed, const void* value, ...) +{ + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u256(seed, pu8 + 0x00); + seed = internal_crc32_u256(seed, pu8 + 0x20); + return seed; +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_u512_sse4(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_u512_sse4(unsigned int seed, const void* value, ...) +{ +#if defined(LIBXSMM_INTRINSICS_SSE42) + const uint8_t *const pu8 = (const uint8_t*)value; + LIBXSMM_ASSERT(NULL != pu8); + seed = internal_crc32_u256_sse4(seed, pu8 + 0x00); + seed = internal_crc32_u256_sse4(seed, pu8 + 0x20); + return seed; +#else + return internal_crc32_u512(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32(unsigned int seed, const void* data, size_t size); +LIBXSMM_API_INTERN unsigned int internal_crc32(unsigned int seed, const void* data, size_t size) +{ + LIBXSMM_ASSERT(NULL != data || 0 == size); + LIBXSMM_HASH(internal_crc32_u64, internal_crc32_u32, internal_crc32_u16, internal_crc32_u8, seed, data, size); +} + + +LIBXSMM_API_INTERN unsigned int internal_crc32_sse4(unsigned int seed, const void* data, size_t size); +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_SSE42) +unsigned int internal_crc32_sse4(unsigned int seed, const void* data, size_t size) +{ + LIBXSMM_ASSERT(NULL != data || 0 == size); +#if defined(LIBXSMM_INTRINSICS_SSE42) + LIBXSMM_HASH(LIBXSMM_HASH_CRC32_U64, LIBXSMM_HASH_CRC32_U32, LIBXSMM_HASH_CRC32_U16, LIBXSMM_HASH_CRC32_U8, seed, data, size); +#else + return internal_crc32(seed, data, size); +#endif +} + + +LIBXSMM_API_INTERN void libxsmm_hash_init(int target_arch) +{ + /* table-based implementation taken from http://dpdk.org/. */ + static const LIBXSMM_RETARGETABLE internal_crc32_entry_type crc32_table[] = { + { /*table0*/ + 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, + 0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, + 0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, + 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, + 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, + 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, + 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A, + 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595, + 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, + 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, + 0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, + 0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, + 0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, + 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, + 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6, + 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829, + 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93, + 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, + 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, + 0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, + 0xA24BB5A6, 0x502036A5, 0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, + 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, + 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, + 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, + 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F, + 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0, + 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, + 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, + 0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, + 0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, + 0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, + 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351 + }, + { /*table1*/ + 0x00000000, 0x13A29877, 0x274530EE, 0x34E7A899, 0x4E8A61DC, 0x5D28F9AB, 0x69CF5132, 0x7A6DC945, + 0x9D14C3B8, 0x8EB65BCF, 0xBA51F356, 0xA9F36B21, 0xD39EA264, 0xC03C3A13, 0xF4DB928A, 0xE7790AFD, + 0x3FC5F181, 0x2C6769F6, 0x1880C16F, 0x0B225918, 0x714F905D, 0x62ED082A, 0x560AA0B3, 0x45A838C4, + 0xA2D13239, 0xB173AA4E, 0x859402D7, 0x96369AA0, 0xEC5B53E5, 0xFFF9CB92, 0xCB1E630B, 0xD8BCFB7C, + 0x7F8BE302, 0x6C297B75, 0x58CED3EC, 0x4B6C4B9B, 0x310182DE, 0x22A31AA9, 0x1644B230, 0x05E62A47, + 0xE29F20BA, 0xF13DB8CD, 0xC5DA1054, 0xD6788823, 0xAC154166, 0xBFB7D911, 0x8B507188, 0x98F2E9FF, + 0x404E1283, 0x53EC8AF4, 0x670B226D, 0x74A9BA1A, 0x0EC4735F, 0x1D66EB28, 0x298143B1, 0x3A23DBC6, + 0xDD5AD13B, 0xCEF8494C, 0xFA1FE1D5, 0xE9BD79A2, 0x93D0B0E7, 0x80722890, 0xB4958009, 0xA737187E, + 0xFF17C604, 0xECB55E73, 0xD852F6EA, 0xCBF06E9D, 0xB19DA7D8, 0xA23F3FAF, 0x96D89736, 0x857A0F41, + 0x620305BC, 0x71A19DCB, 0x45463552, 0x56E4AD25, 0x2C896460, 0x3F2BFC17, 0x0BCC548E, 0x186ECCF9, + 0xC0D23785, 0xD370AFF2, 0xE797076B, 0xF4359F1C, 0x8E585659, 0x9DFACE2E, 0xA91D66B7, 0xBABFFEC0, + 0x5DC6F43D, 0x4E646C4A, 0x7A83C4D3, 0x69215CA4, 0x134C95E1, 0x00EE0D96, 0x3409A50F, 0x27AB3D78, + 0x809C2506, 0x933EBD71, 0xA7D915E8, 0xB47B8D9F, 0xCE1644DA, 0xDDB4DCAD, 0xE9537434, 0xFAF1EC43, + 0x1D88E6BE, 0x0E2A7EC9, 0x3ACDD650, 0x296F4E27, 0x53028762, 0x40A01F15, 0x7447B78C, 0x67E52FFB, + 0xBF59D487, 0xACFB4CF0, 0x981CE469, 0x8BBE7C1E, 0xF1D3B55B, 0xE2712D2C, 0xD69685B5, 0xC5341DC2, + 0x224D173F, 0x31EF8F48, 0x050827D1, 0x16AABFA6, 0x6CC776E3, 0x7F65EE94, 0x4B82460D, 0x5820DE7A, + 0xFBC3FAF9, 0xE861628E, 0xDC86CA17, 0xCF245260, 0xB5499B25, 0xA6EB0352, 0x920CABCB, 0x81AE33BC, + 0x66D73941, 0x7575A136, 0x419209AF, 0x523091D8, 0x285D589D, 0x3BFFC0EA, 0x0F186873, 0x1CBAF004, + 0xC4060B78, 0xD7A4930F, 0xE3433B96, 0xF0E1A3E1, 0x8A8C6AA4, 0x992EF2D3, 0xADC95A4A, 0xBE6BC23D, + 0x5912C8C0, 0x4AB050B7, 0x7E57F82E, 0x6DF56059, 0x1798A91C, 0x043A316B, 0x30DD99F2, 0x237F0185, + 0x844819FB, 0x97EA818C, 0xA30D2915, 0xB0AFB162, 0xCAC27827, 0xD960E050, 0xED8748C9, 0xFE25D0BE, + 0x195CDA43, 0x0AFE4234, 0x3E19EAAD, 0x2DBB72DA, 0x57D6BB9F, 0x447423E8, 0x70938B71, 0x63311306, + 0xBB8DE87A, 0xA82F700D, 0x9CC8D894, 0x8F6A40E3, 0xF50789A6, 0xE6A511D1, 0xD242B948, 0xC1E0213F, + 0x26992BC2, 0x353BB3B5, 0x01DC1B2C, 0x127E835B, 0x68134A1E, 0x7BB1D269, 0x4F567AF0, 0x5CF4E287, + 0x04D43CFD, 0x1776A48A, 0x23910C13, 0x30339464, 0x4A5E5D21, 0x59FCC556, 0x6D1B6DCF, 0x7EB9F5B8, + 0x99C0FF45, 0x8A626732, 0xBE85CFAB, 0xAD2757DC, 0xD74A9E99, 0xC4E806EE, 0xF00FAE77, 0xE3AD3600, + 0x3B11CD7C, 0x28B3550B, 0x1C54FD92, 0x0FF665E5, 0x759BACA0, 0x663934D7, 0x52DE9C4E, 0x417C0439, + 0xA6050EC4, 0xB5A796B3, 0x81403E2A, 0x92E2A65D, 0xE88F6F18, 0xFB2DF76F, 0xCFCA5FF6, 0xDC68C781, + 0x7B5FDFFF, 0x68FD4788, 0x5C1AEF11, 0x4FB87766, 0x35D5BE23, 0x26772654, 0x12908ECD, 0x013216BA, + 0xE64B1C47, 0xF5E98430, 0xC10E2CA9, 0xD2ACB4DE, 0xA8C17D9B, 0xBB63E5EC, 0x8F844D75, 0x9C26D502, + 0x449A2E7E, 0x5738B609, 0x63DF1E90, 0x707D86E7, 0x0A104FA2, 0x19B2D7D5, 0x2D557F4C, 0x3EF7E73B, + 0xD98EEDC6, 0xCA2C75B1, 0xFECBDD28, 0xED69455F, 0x97048C1A, 0x84A6146D, 0xB041BCF4, 0xA3E32483 + }, + { /*table2*/ + 0x00000000, 0xA541927E, 0x4F6F520D, 0xEA2EC073, 0x9EDEA41A, 0x3B9F3664, 0xD1B1F617, 0x74F06469, + 0x38513EC5, 0x9D10ACBB, 0x773E6CC8, 0xD27FFEB6, 0xA68F9ADF, 0x03CE08A1, 0xE9E0C8D2, 0x4CA15AAC, + 0x70A27D8A, 0xD5E3EFF4, 0x3FCD2F87, 0x9A8CBDF9, 0xEE7CD990, 0x4B3D4BEE, 0xA1138B9D, 0x045219E3, + 0x48F3434F, 0xEDB2D131, 0x079C1142, 0xA2DD833C, 0xD62DE755, 0x736C752B, 0x9942B558, 0x3C032726, + 0xE144FB14, 0x4405696A, 0xAE2BA919, 0x0B6A3B67, 0x7F9A5F0E, 0xDADBCD70, 0x30F50D03, 0x95B49F7D, + 0xD915C5D1, 0x7C5457AF, 0x967A97DC, 0x333B05A2, 0x47CB61CB, 0xE28AF3B5, 0x08A433C6, 0xADE5A1B8, + 0x91E6869E, 0x34A714E0, 0xDE89D493, 0x7BC846ED, 0x0F382284, 0xAA79B0FA, 0x40577089, 0xE516E2F7, + 0xA9B7B85B, 0x0CF62A25, 0xE6D8EA56, 0x43997828, 0x37691C41, 0x92288E3F, 0x78064E4C, 0xDD47DC32, + 0xC76580D9, 0x622412A7, 0x880AD2D4, 0x2D4B40AA, 0x59BB24C3, 0xFCFAB6BD, 0x16D476CE, 0xB395E4B0, + 0xFF34BE1C, 0x5A752C62, 0xB05BEC11, 0x151A7E6F, 0x61EA1A06, 0xC4AB8878, 0x2E85480B, 0x8BC4DA75, + 0xB7C7FD53, 0x12866F2D, 0xF8A8AF5E, 0x5DE93D20, 0x29195949, 0x8C58CB37, 0x66760B44, 0xC337993A, + 0x8F96C396, 0x2AD751E8, 0xC0F9919B, 0x65B803E5, 0x1148678C, 0xB409F5F2, 0x5E273581, 0xFB66A7FF, + 0x26217BCD, 0x8360E9B3, 0x694E29C0, 0xCC0FBBBE, 0xB8FFDFD7, 0x1DBE4DA9, 0xF7908DDA, 0x52D11FA4, + 0x1E704508, 0xBB31D776, 0x511F1705, 0xF45E857B, 0x80AEE112, 0x25EF736C, 0xCFC1B31F, 0x6A802161, + 0x56830647, 0xF3C29439, 0x19EC544A, 0xBCADC634, 0xC85DA25D, 0x6D1C3023, 0x8732F050, 0x2273622E, + 0x6ED23882, 0xCB93AAFC, 0x21BD6A8F, 0x84FCF8F1, 0xF00C9C98, 0x554D0EE6, 0xBF63CE95, 0x1A225CEB, + 0x8B277743, 0x2E66E53D, 0xC448254E, 0x6109B730, 0x15F9D359, 0xB0B84127, 0x5A968154, 0xFFD7132A, + 0xB3764986, 0x1637DBF8, 0xFC191B8B, 0x595889F5, 0x2DA8ED9C, 0x88E97FE2, 0x62C7BF91, 0xC7862DEF, + 0xFB850AC9, 0x5EC498B7, 0xB4EA58C4, 0x11ABCABA, 0x655BAED3, 0xC01A3CAD, 0x2A34FCDE, 0x8F756EA0, + 0xC3D4340C, 0x6695A672, 0x8CBB6601, 0x29FAF47F, 0x5D0A9016, 0xF84B0268, 0x1265C21B, 0xB7245065, + 0x6A638C57, 0xCF221E29, 0x250CDE5A, 0x804D4C24, 0xF4BD284D, 0x51FCBA33, 0xBBD27A40, 0x1E93E83E, + 0x5232B292, 0xF77320EC, 0x1D5DE09F, 0xB81C72E1, 0xCCEC1688, 0x69AD84F6, 0x83834485, 0x26C2D6FB, + 0x1AC1F1DD, 0xBF8063A3, 0x55AEA3D0, 0xF0EF31AE, 0x841F55C7, 0x215EC7B9, 0xCB7007CA, 0x6E3195B4, + 0x2290CF18, 0x87D15D66, 0x6DFF9D15, 0xC8BE0F6B, 0xBC4E6B02, 0x190FF97C, 0xF321390F, 0x5660AB71, + 0x4C42F79A, 0xE90365E4, 0x032DA597, 0xA66C37E9, 0xD29C5380, 0x77DDC1FE, 0x9DF3018D, 0x38B293F3, + 0x7413C95F, 0xD1525B21, 0x3B7C9B52, 0x9E3D092C, 0xEACD6D45, 0x4F8CFF3B, 0xA5A23F48, 0x00E3AD36, + 0x3CE08A10, 0x99A1186E, 0x738FD81D, 0xD6CE4A63, 0xA23E2E0A, 0x077FBC74, 0xED517C07, 0x4810EE79, + 0x04B1B4D5, 0xA1F026AB, 0x4BDEE6D8, 0xEE9F74A6, 0x9A6F10CF, 0x3F2E82B1, 0xD50042C2, 0x7041D0BC, + 0xAD060C8E, 0x08479EF0, 0xE2695E83, 0x4728CCFD, 0x33D8A894, 0x96993AEA, 0x7CB7FA99, 0xD9F668E7, + 0x9557324B, 0x3016A035, 0xDA386046, 0x7F79F238, 0x0B899651, 0xAEC8042F, 0x44E6C45C, 0xE1A75622, + 0xDDA47104, 0x78E5E37A, 0x92CB2309, 0x378AB177, 0x437AD51E, 0xE63B4760, 0x0C158713, 0xA954156D, + 0xE5F54FC1, 0x40B4DDBF, 0xAA9A1DCC, 0x0FDB8FB2, 0x7B2BEBDB, 0xDE6A79A5, 0x3444B9D6, 0x91052BA8 + }, + { /*table3*/ + 0x00000000, 0xDD45AAB8, 0xBF672381, 0x62228939, 0x7B2231F3, 0xA6679B4B, 0xC4451272, 0x1900B8CA, + 0xF64463E6, 0x2B01C95E, 0x49234067, 0x9466EADF, 0x8D665215, 0x5023F8AD, 0x32017194, 0xEF44DB2C, + 0xE964B13D, 0x34211B85, 0x560392BC, 0x8B463804, 0x924680CE, 0x4F032A76, 0x2D21A34F, 0xF06409F7, + 0x1F20D2DB, 0xC2657863, 0xA047F15A, 0x7D025BE2, 0x6402E328, 0xB9474990, 0xDB65C0A9, 0x06206A11, + 0xD725148B, 0x0A60BE33, 0x6842370A, 0xB5079DB2, 0xAC072578, 0x71428FC0, 0x136006F9, 0xCE25AC41, + 0x2161776D, 0xFC24DDD5, 0x9E0654EC, 0x4343FE54, 0x5A43469E, 0x8706EC26, 0xE524651F, 0x3861CFA7, + 0x3E41A5B6, 0xE3040F0E, 0x81268637, 0x5C632C8F, 0x45639445, 0x98263EFD, 0xFA04B7C4, 0x27411D7C, + 0xC805C650, 0x15406CE8, 0x7762E5D1, 0xAA274F69, 0xB327F7A3, 0x6E625D1B, 0x0C40D422, 0xD1057E9A, + 0xABA65FE7, 0x76E3F55F, 0x14C17C66, 0xC984D6DE, 0xD0846E14, 0x0DC1C4AC, 0x6FE34D95, 0xB2A6E72D, + 0x5DE23C01, 0x80A796B9, 0xE2851F80, 0x3FC0B538, 0x26C00DF2, 0xFB85A74A, 0x99A72E73, 0x44E284CB, + 0x42C2EEDA, 0x9F874462, 0xFDA5CD5B, 0x20E067E3, 0x39E0DF29, 0xE4A57591, 0x8687FCA8, 0x5BC25610, + 0xB4868D3C, 0x69C32784, 0x0BE1AEBD, 0xD6A40405, 0xCFA4BCCF, 0x12E11677, 0x70C39F4E, 0xAD8635F6, + 0x7C834B6C, 0xA1C6E1D4, 0xC3E468ED, 0x1EA1C255, 0x07A17A9F, 0xDAE4D027, 0xB8C6591E, 0x6583F3A6, + 0x8AC7288A, 0x57828232, 0x35A00B0B, 0xE8E5A1B3, 0xF1E51979, 0x2CA0B3C1, 0x4E823AF8, 0x93C79040, + 0x95E7FA51, 0x48A250E9, 0x2A80D9D0, 0xF7C57368, 0xEEC5CBA2, 0x3380611A, 0x51A2E823, 0x8CE7429B, + 0x63A399B7, 0xBEE6330F, 0xDCC4BA36, 0x0181108E, 0x1881A844, 0xC5C402FC, 0xA7E68BC5, 0x7AA3217D, + 0x52A0C93F, 0x8FE56387, 0xEDC7EABE, 0x30824006, 0x2982F8CC, 0xF4C75274, 0x96E5DB4D, 0x4BA071F5, + 0xA4E4AAD9, 0x79A10061, 0x1B838958, 0xC6C623E0, 0xDFC69B2A, 0x02833192, 0x60A1B8AB, 0xBDE41213, + 0xBBC47802, 0x6681D2BA, 0x04A35B83, 0xD9E6F13B, 0xC0E649F1, 0x1DA3E349, 0x7F816A70, 0xA2C4C0C8, + 0x4D801BE4, 0x90C5B15C, 0xF2E73865, 0x2FA292DD, 0x36A22A17, 0xEBE780AF, 0x89C50996, 0x5480A32E, + 0x8585DDB4, 0x58C0770C, 0x3AE2FE35, 0xE7A7548D, 0xFEA7EC47, 0x23E246FF, 0x41C0CFC6, 0x9C85657E, + 0x73C1BE52, 0xAE8414EA, 0xCCA69DD3, 0x11E3376B, 0x08E38FA1, 0xD5A62519, 0xB784AC20, 0x6AC10698, + 0x6CE16C89, 0xB1A4C631, 0xD3864F08, 0x0EC3E5B0, 0x17C35D7A, 0xCA86F7C2, 0xA8A47EFB, 0x75E1D443, + 0x9AA50F6F, 0x47E0A5D7, 0x25C22CEE, 0xF8878656, 0xE1873E9C, 0x3CC29424, 0x5EE01D1D, 0x83A5B7A5, + 0xF90696D8, 0x24433C60, 0x4661B559, 0x9B241FE1, 0x8224A72B, 0x5F610D93, 0x3D4384AA, 0xE0062E12, + 0x0F42F53E, 0xD2075F86, 0xB025D6BF, 0x6D607C07, 0x7460C4CD, 0xA9256E75, 0xCB07E74C, 0x16424DF4, + 0x106227E5, 0xCD278D5D, 0xAF050464, 0x7240AEDC, 0x6B401616, 0xB605BCAE, 0xD4273597, 0x09629F2F, + 0xE6264403, 0x3B63EEBB, 0x59416782, 0x8404CD3A, 0x9D0475F0, 0x4041DF48, 0x22635671, 0xFF26FCC9, + 0x2E238253, 0xF36628EB, 0x9144A1D2, 0x4C010B6A, 0x5501B3A0, 0x88441918, 0xEA669021, 0x37233A99, + 0xD867E1B5, 0x05224B0D, 0x6700C234, 0xBA45688C, 0xA345D046, 0x7E007AFE, 0x1C22F3C7, 0xC167597F, + 0xC747336E, 0x1A0299D6, 0x782010EF, 0xA565BA57, 0xBC65029D, 0x6120A825, 0x0302211C, 0xDE478BA4, + 0x31035088, 0xEC46FA30, 0x8E647309, 0x5321D9B1, 0x4A21617B, 0x9764CBC3, 0xF54642FA, 0x2803E842 + } + }; + internal_crc32_table = crc32_table; +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + LIBXSMM_UNUSED(target_arch); +#else + if (LIBXSMM_X86_SSE42 <= target_arch) +#endif + { + internal_hash_u32_function = internal_crc32_u32_sse4; + internal_hash_u64_function = internal_crc32_u64_sse4; + internal_hash_u128_function = internal_crc32_u128_sse4; + internal_hash_u256_function = internal_crc32_u256_sse4; + internal_hash_u384_function = internal_crc32_u384_sse4; + internal_hash_u512_function = internal_crc32_u512_sse4; + internal_hash_function = (libxsmm_hash_function)internal_crc32_sse4; + } +#if (LIBXSMM_X86_SSE42 > LIBXSMM_STATIC_TARGET_ARCH) + else { +# if defined(LIBXSMM_PLATFORM_X86) && !defined(LIBXSMM_INTRINSICS_SSE42) + static int error_once = 0; + if (0 == error_once && 0 != libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM WARNING: unable to access CRC32 instructions due to the compiler used!\n"); + error_once = 1; /* no need for atomics */ + } +# endif + internal_hash_u32_function = internal_crc32_u32; + internal_hash_u64_function = internal_crc32_u64; + internal_hash_u128_function = internal_crc32_u128; + internal_hash_u256_function = internal_crc32_u256; + internal_hash_u384_function = internal_crc32_u384; + internal_hash_u512_function = internal_crc32_u512; + internal_hash_function = (libxsmm_hash_function)internal_crc32; + } +#endif + LIBXSMM_ASSERT(NULL != internal_hash_u32_function); + LIBXSMM_ASSERT(NULL != internal_hash_u64_function); + LIBXSMM_ASSERT(NULL != internal_hash_u128_function); + LIBXSMM_ASSERT(NULL != internal_hash_u256_function); + LIBXSMM_ASSERT(NULL != internal_hash_u384_function); + LIBXSMM_ASSERT(NULL != internal_hash_u512_function); + LIBXSMM_ASSERT(NULL != internal_hash_function); +} + + +LIBXSMM_API_INTERN void libxsmm_hash_finalize(void) +{ +#if !defined(NDEBUG) + internal_crc32_table = NULL; + internal_hash_u32_function = NULL; + internal_hash_u64_function = NULL; + internal_hash_u128_function = NULL; + internal_hash_u256_function = NULL; + internal_hash_u384_function = NULL; + internal_hash_u512_function = NULL; + internal_hash_function = NULL; +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u32(unsigned int seed, const void* value, ...) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return LIBXSMM_HASH_CRC32_U32(seed, value); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32_u32(seed, value); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_u32_function); + return internal_hash_u32_function(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u64(unsigned int seed, const void* value, ...) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return (unsigned int)LIBXSMM_HASH_CRC32_U64(seed, value); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32_u64(seed, value); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_u64_function); + return internal_hash_u64_function(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u128(unsigned int seed, const void* value, ...) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_crc32_u128_sse4(seed, value); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32_u128(seed, value); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_u128_function); + return internal_hash_u128_function(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u256(unsigned int seed, const void* value, ...) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_crc32_u256_sse4(seed, value); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32_u256(seed, value); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_u256_function); + return internal_hash_u256_function(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u384(unsigned int seed, const void* value, ...) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_crc32_u384_sse4(seed, value); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32_u384(seed, value); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_u384_function); + return internal_hash_u384_function(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u512(unsigned int seed, const void* value, ...) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_crc32_u512_sse4(seed, value); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32_u512(seed, value); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_u512_function); + return internal_hash_u512_function(seed, value); +#endif +} + + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32(unsigned int seed, const void* data, size_t size) +{ +#if (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_crc32_sse4(seed, data, size); +#elif (LIBXSMM_X86_SSE42 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_crc32(seed, data, size); +#else /* pointer based function call */ + LIBXSMM_ASSERT(NULL != internal_hash_function); + return internal_hash_function(seed, data, size); +#endif +} + diff --git a/third_party/libxsmm/src/libxsmm_hash.h b/third_party/libxsmm/src/libxsmm_hash.h new file mode 100644 index 0000000000000000000000000000000000000000..c5df564ef36bf8c9b6407b15b34506f80e15e333 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_hash.h @@ -0,0 +1,47 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_HASH_H +#define LIBXSMM_HASH_H + +#include + +/* Map number of Bits to corresponding routine. */ +#define LIBXSMM_CRC32U(N) LIBXSMM_CONCATENATE(libxsmm_crc32_u, N) +/* Map number of Bytes to number of bits. */ +#define LIBXSMM_CRC32(N) LIBXSMM_CONCATENATE(libxsmm_crc32_b, N) +#define libxsmm_crc32_b4 libxsmm_crc32_u32 +#define libxsmm_crc32_b8 libxsmm_crc32_u64 +#define libxsmm_crc32_b16 libxsmm_crc32_u128 +#define libxsmm_crc32_b32 libxsmm_crc32_u256 +#define libxsmm_crc32_b48 libxsmm_crc32_u384 +#define libxsmm_crc32_b64 libxsmm_crc32_u512 + + +/** Function type representing the CRC32 functionality. */ +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE unsigned int (*libxsmm_hash_function)( + unsigned int /*seed*/, const void* /*data*/, ... /*size*/); + +/** Initialize hash function module; not thread-safe. */ +LIBXSMM_API_INTERN void libxsmm_hash_init(int target_arch); +LIBXSMM_API_INTERN void libxsmm_hash_finalize(void); + +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u32(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u64(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u128(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u256(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u384(unsigned int seed, const void* value, ...); +LIBXSMM_API_INTERN unsigned int libxsmm_crc32_u512(unsigned int seed, const void* value, ...); + +/** Calculate the CRC32 for a given quantity (size) of raw data according to the seed. */ +LIBXSMM_API_INTERN unsigned int libxsmm_crc32(unsigned int seed, const void* data, size_t size); + +#endif /*LIBXSMM_HASH_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_main.c b/third_party/libxsmm/src/libxsmm_main.c new file mode 100644 index 0000000000000000000000000000000000000000..4326fffd436c9216935c2a50f3841761d3a0c8c3 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_main.c @@ -0,0 +1,4981 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_trace.h" +#include "libxsmm_xcopy.h" +#include "libxsmm_gemm.h" +#include "libxsmm_hash.h" +#include "libxsmm_diff.h" +#include "libxsmm_main.h" +#if defined(LIBXSMM_PERF) +# include "libxsmm_perf.h" +#endif +#include "generator_common.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if !defined(NDEBUG) +# include +#endif +#if defined(_WIN32) +# include +#else +# include +# include +# include +# include +# include +#endif +#if defined(__APPLE__) +# include +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(LIBXSMM_CODE_MAXSIZE) +# define LIBXSMM_CODE_MAXSIZE 131072 +#endif +#if !defined(LIBXSMM_DIFF_SIZE) +# define LIBXSMM_DIFF_SIZE LIBXSMM_DESCRIPTOR_SIGSIZE +#endif +#if !defined(LIBXSMM_HASH_SIZE) +/* can be smaller than MAXSIZE/SIGSIZE at the expense of collisions */ +# define LIBXSMM_HASH_SIZE 32 +#endif +#if !defined(LIBXSMM_HASH_SEED) +# define LIBXSMM_HASH_SEED 25071975 +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_ALIGN) && 1 +# define LIBXSMM_MALLOC_HOOK_ALIGN +#endif +#if !defined(LIBXSMM_ENABLE_DEREG) && 0 +# define LIBXSMM_ENABLE_DEREG +#endif +#if !defined(LIBXSMM_REGUSER_HASH) && 1 +# define LIBXSMM_REGUSER_HASH +#endif +#if !defined(LIBXSMM_REGLOCK_TRY) && 0 +# define LIBXSMM_REGLOCK_TRY +#endif +#if !defined(LIBXSMM_UNIFY_LOCKS) && 1 +# define LIBXSMM_UNIFY_LOCKS +#endif +#if !defined(LIBXSMM_REGKEY_PAD) && 0 +# define LIBXSMM_REGKEY_PAD +#endif +#if !defined(LIBXSMM_CACHE_PAD) && 1 +# define LIBXSMM_CACHE_PAD +#endif +#if !defined(LIBXSMM_AUTOPIN) && 0 +# define LIBXSMM_AUTOPIN +#endif +#if !defined(INTERNAL_DELIMS) +# define INTERNAL_DELIMS ";,:" +#endif + +#if !defined(_WIN32) && !defined(__CYGWIN__) +LIBXSMM_EXTERN int posix_memalign(void**, size_t, size_t) LIBXSMM_THROW; +#endif +#if defined(LIBXSMM_AUTOPIN) && !defined(_WIN32) +LIBXSMM_EXTERN int putenv(char*) LIBXSMM_THROW; +#endif + +/* flag fused into the memory address of a code version in case of non-JIT */ +#define LIBXSMM_CODE_STATIC (1ULL << (8 * sizeof(void*) - 1)) +/* flag fused into the memory address of a code version in case of collision */ +#if 1 /* beneficial when registry approaches capacity (collisions) */ +# define LIBXSMM_HASH_COLLISION (1ULL << (8 * sizeof(void*) - 2)) +#endif + +/** Helper macro determining the default prefetch strategy which is used for statically generated kernels. */ +#if (0 > LIBXSMM_PREFETCH) /* auto-prefetch (frontend) */ || (defined(_WIN32) || defined(__CYGWIN__)) +# define INTERNAL_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE +#else +# define INTERNAL_PREFETCH ((libxsmm_gemm_prefetch_type)LIBXSMM_PREFETCH) +#endif + +#if (0 != LIBXSMM_SYNC) +# if !defined(INTERNAL_REGLOCK_MAXN) +# if defined(_MSC_VER) +# define INTERNAL_REGLOCK_MAXN 0 +# else +# define INTERNAL_REGLOCK_MAXN 0 +# endif +# endif +# if (1 < INTERNAL_REGLOCK_MAXN) +# if !defined(LIBXSMM_CACHE_MAXSIZE) && (8 > INTERNAL_REGLOCK_MAXN) +# define LIBXSMM_CACHE_MAXSIZE LIBXSMM_CAPACITY_CACHE +# endif +# if !defined(LIBXSMM_REGLOCK) +# define LIBXSMM_REGLOCK LIBXSMM_LOCK_DEFAULT +# endif +# if !defined(LIBXSMM_CLEANUP_NTRY) +# define LIBXSMM_CLEANUP_NTRY 7 +# endif +# if LIBXSMM_LOCK_TYPE_ISPOD(LIBXSMM_REGLOCK) +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_reglocktype { + char pad[LIBXSMM_CACHELINE]; + LIBXSMM_LOCK_TYPE(LIBXSMM_REGLOCK) state; +} internal_reglocktype; +# else +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_reglocktype { + LIBXSMM_LOCK_TYPE(LIBXSMM_REGLOCK) state; +} internal_reglocktype; +# endif +LIBXSMM_APIVAR_DEFINE(internal_reglocktype internal_reglock[INTERNAL_REGLOCK_MAXN]); +# else /* RW-lock */ +# if !defined(LIBXSMM_CACHE_MAXSIZE) +# define LIBXSMM_CACHE_MAXSIZE LIBXSMM_CAPACITY_CACHE +# endif +# if !defined(LIBXSMM_REGLOCK) +# if defined(LIBXSMM_UNIFY_LOCKS) +# define LIBXSMM_REGLOCK LIBXSMM_LOCK +# elif defined(_MSC_VER) +# define LIBXSMM_REGLOCK LIBXSMM_LOCK_MUTEX +# elif 0 +# define LIBXSMM_REGLOCK LIBXSMM_LOCK_RWLOCK +# else +# define LIBXSMM_REGLOCK LIBXSMM_LOCK_DEFAULT +# endif +# endif +LIBXSMM_APIVAR_DEFINE(LIBXSMM_LOCK_TYPE(LIBXSMM_REGLOCK)* internal_reglock_ptr); +# endif +#elif !defined(LIBXSMM_CACHE_MAXSIZE) +# define LIBXSMM_CACHE_MAXSIZE LIBXSMM_CAPACITY_CACHE +#endif +#if defined(LIBXSMM_UNPACKED) /* CCE/Classic */ +# define LIBXSMM_CACHE_STRIDE LIBXSMM_MAX(sizeof(libxsmm_descriptor), LIBXSMM_DESCRIPTOR_MAXSIZE) +#else +# define LIBXSMM_CACHE_STRIDE LIBXSMM_DESCRIPTOR_MAXSIZE +#endif + +#if defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) +# define INTERNAL_FIND_CODE_CACHE_GROW(RESULT_INDEX, CACHE_SIZE) \ + RESULT_INDEX = CACHE_SIZE; CACHE_SIZE = (unsigned char)(0 != (CACHE_SIZE) ? ((CACHE_SIZE) << 1) : 1) +# define INTERNAL_FIND_CODE_CACHE_EVICT(RESULT_INDEX, CACHE_SIZE, CACHE_HIT) \ + RESULT_INDEX = (unsigned char)LIBXSMM_MOD2((CACHE_HIT) + ((CACHE_SIZE) - 1), CACHE_SIZE) +#endif + +#if (0 == LIBXSMM_SYNC) +# define INTERNAL_FIND_CODE_LOCK(LOCKINDEX, INDEX, DIFF, CODE) { +# define INTERNAL_FIND_CODE_UNLOCK(LOCKINDEX) } +#else +# if defined(LIBXSMM_REGLOCK_TRY) +# define INTERNAL_REGLOCK_TRY(DIFF, CODE) \ + if (1 != internal_reglock_count) { /* (re-)try and get (meanwhile) generated code */ \ + LIBXSMM_ASSERT(NULL != internal_registry); /* engine is not shut down */ \ + continue; \ + } \ + else { /* exit dispatch and let client fall back */ \ + DIFF = 0; CODE = 0; break; \ + } +# else +# define INTERNAL_REGLOCK_TRY(DIFF, CODE) \ + LIBXSMM_ASSERT(NULL != internal_registry); /* engine is not shut down */ \ + continue +# endif +# if (1 < INTERNAL_REGLOCK_MAXN) +# define INTERNAL_FIND_CODE_LOCK(LOCKINDEX, INDEX, DIFF, CODE) { \ + const unsigned int LOCKINDEX = (0 != internal_reglock_count ? LIBXSMM_MOD2(INDEX, internal_reglock_count) : 0); \ + if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_REGLOCK) != LIBXSMM_LOCK_TRYLOCK(LIBXSMM_REGLOCK, &internal_reglock[LOCKINDEX].state)) { \ + INTERNAL_REGLOCK_TRY(DIFF, CODE); \ + } +# define INTERNAL_FIND_CODE_UNLOCK(LOCKINDEX) LIBXSMM_LOCK_RELEASE(LIBXSMM_REGLOCK, &internal_reglock[LOCKINDEX].state); } +# else /* RW-lock */ +# define INTERNAL_FIND_CODE_LOCK(LOCKINDEX, INDEX, DIFF, CODE) { \ + if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_REGLOCK) != LIBXSMM_LOCK_TRYLOCK(LIBXSMM_REGLOCK, internal_reglock_ptr)) { \ + INTERNAL_REGLOCK_TRY(DIFF, CODE); \ + } +# define INTERNAL_FIND_CODE_UNLOCK(LOCKINDEX) LIBXSMM_LOCK_RELEASE(LIBXSMM_REGLOCK, internal_reglock_ptr); } +# endif +#endif + + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE internal_statistic_type { + unsigned int ntry, ncol, njit, nsta; +} internal_statistic_type; + +#if defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE internal_cache_entry_type { + libxsmm_descriptor keys[LIBXSMM_CACHE_MAXSIZE]; + libxsmm_code_pointer code[LIBXSMM_CACHE_MAXSIZE]; + unsigned int id; /* to invalidate */ + unsigned char size, hit; +} internal_cache_entry_type; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_cache_type { +# if defined(LIBXSMM_CACHE_PAD) + char pad[LIBXSMM_UP2(sizeof(internal_cache_entry_type),LIBXSMM_CACHELINE)]; +# endif + internal_cache_entry_type entry; +} internal_cache_type; + +# if defined(LIBXSMM_NTHREADS_USE) +LIBXSMM_APIVAR_DEFINE(internal_cache_type* internal_cache_buffer); +# endif +LIBXSMM_APIVAR_DEFINE(int internal_cache_size); +#endif /*defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE))*/ + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_regkey_type { +#if defined(LIBXSMM_REGKEY_PAD) + char pad[LIBXSMM_UP2(sizeof(libxsmm_descriptor), LIBXSMM_CACHELINE)]; +#endif + libxsmm_descriptor entry; +} internal_regkey_type; + +/** Determines the try-lock property (1m && 1 < desc->n) { /* only record matrix-matrix multiplication */ + const unsigned long long kernel_size = LIBXSMM_MNK_SIZE(desc->m, desc->n, desc->k); + const int idx = (LIBXSMM_GEMM_PRECISION_F64 == LIBXSMM_GETENUM_OUT(desc->datatype) ? 0 : 1); + int bucket; + if (LIBXSMM_MNK_SIZE(internal_statistic_sml, internal_statistic_sml, internal_statistic_sml) >= kernel_size) { + bucket = 0; + } + else if (LIBXSMM_MNK_SIZE(internal_statistic_med, internal_statistic_med, internal_statistic_med) >= kernel_size) { + bucket = 1; + } + else if (LIBXSMM_MNK_SIZE(internal_statistic_mnk, internal_statistic_mnk, internal_statistic_mnk) >= kernel_size) { + bucket = 2; + } + else { /*huge*/ + bucket = 3; + } + if (0 != ncol) ncol/*dummy assignment*/ = LIBXSMM_ATOMIC_ADD_FETCH(&internal_statistic[idx][bucket].ncol, ncol, LIBXSMM_ATOMIC_RELAXED); + if (0 != ntry) ntry/*dummy assignment*/ = LIBXSMM_ATOMIC_ADD_FETCH(&internal_statistic[idx][bucket].ntry, ntry, LIBXSMM_ATOMIC_RELAXED); + /* the following counters are not manipulated concurrently (no need for atomic increment) */ + if (0 != njit) internal_statistic[idx][bucket].njit += njit; + if (0 != nsta) internal_statistic[idx][bucket].nsta += nsta; + } +} + + +LIBXSMM_API_INLINE unsigned int internal_print_number(unsigned int n, char default_unit, char* unit) +{ + unsigned int number = n; + LIBXSMM_ASSERT(NULL != unit); + *unit = default_unit; + if ((1000000) <= n) { + number = (n + 500000) / 1000000; + *unit = 'm'; + } + else if (9999 < n) { + number = (n + 500) / 1000; + *unit = 'k'; + } + return number; +} + + +LIBXSMM_API_INLINE unsigned int internal_print_statistic(FILE* ostream, + const char* target_arch, int precision, unsigned int linebreaks, unsigned int indent) +{ + const internal_statistic_type statistic_sml = internal_statistic[precision][0/*SML*/]; + const internal_statistic_type statistic_med = internal_statistic[precision][1/*MED*/]; + const internal_statistic_type statistic_big = internal_statistic[precision][2/*BIG*/]; + const internal_statistic_type statistic_xxx = internal_statistic[precision][3/*XXX*/]; + int printed = 0; + LIBXSMM_ASSERT(NULL != ostream && (0 <= precision && precision < 2)); + if (/* omit to print anything if it is superfluous */ + 0 != statistic_sml.ntry || 0 != statistic_sml.njit || 0 != statistic_sml.nsta || 0 != statistic_sml.ncol || + 0 != statistic_med.ntry || 0 != statistic_med.njit || 0 != statistic_med.nsta || 0 != statistic_med.ncol || + 0 != statistic_big.ntry || 0 != statistic_big.njit || 0 != statistic_big.nsta || 0 != statistic_big.ncol || + 0 != statistic_xxx.ntry || 0 != statistic_xxx.njit || 0 != statistic_xxx.nsta || 0 != statistic_xxx.ncol) + { + char title[256], range[256], unit[4]; + unsigned int counter[4]; + { + unsigned int n; + if (NULL != target_arch && '\0' != *target_arch) { + assert(strlen(target_arch) < sizeof(title)); /* !LIBXSMM_ASSERT */ + for (n = 0; 0 != target_arch[n] /*avoid code-gen. issue with some clang versions: && n < sizeof(title)*/; ++n) { + const char c = target_arch[n]; + title[n] = (char)(('a' <= c && c <= 'z') ? (c - 32) : c); /* toupper */ + } + LIBXSMM_SNPRINTF(title + n, sizeof(title) - n, "/%s", 0 == precision ? "DP" : "SP"); + } + else { + LIBXSMM_SNPRINTF(title, sizeof(title), "%s", 0 == precision ? "DP" : "SP"); + } + for (n = 0; n < linebreaks; ++n) fprintf(ostream, "\n"); + } + fprintf(ostream, "%*s%-8s %6s %6s %6s %6s\n", (int)indent, "", title, "TRY", "JIT", "STA", "COL"); + LIBXSMM_SNPRINTF(range, sizeof(range), "%u..%u", 0u, internal_statistic_sml); + counter[0] = internal_print_number(statistic_sml.ntry, ' ', unit + 0); + counter[1] = internal_print_number(statistic_sml.njit, ' ', unit + 1); + counter[2] = internal_print_number(statistic_sml.nsta, ' ', unit + 2); + counter[3] = internal_print_number(statistic_sml.ncol, ' ', unit + 3); + fprintf(ostream, "%*s%8s %6u%c %5u%c %5u%c %5u%c\n", (int)indent, "", range, + counter[0], unit[0], counter[1], unit[1], counter[2], unit[2], counter[3], unit[3]); + LIBXSMM_SNPRINTF(range, sizeof(range), "%u..%u", internal_statistic_sml + 1u, internal_statistic_med); + counter[0] = internal_print_number(statistic_med.ntry, ' ', unit + 0); + counter[1] = internal_print_number(statistic_med.njit, ' ', unit + 1); + counter[2] = internal_print_number(statistic_med.nsta, ' ', unit + 2); + counter[3] = internal_print_number(statistic_med.ncol, ' ', unit + 3); + fprintf(ostream, "%*s%8s %6u%c %5u%c %5u%c %5u%c\n", (int)indent, "", range, + counter[0], unit[0], counter[1], unit[1], counter[2], unit[2], counter[3], unit[3]); + LIBXSMM_SNPRINTF(range, sizeof(range), "%u..%u", internal_statistic_med + 1u, internal_statistic_mnk); + counter[0] = internal_print_number(statistic_big.ntry, ' ', unit + 0); + counter[1] = internal_print_number(statistic_big.njit, ' ', unit + 1); + counter[2] = internal_print_number(statistic_big.nsta, ' ', unit + 2); + counter[3] = internal_print_number(statistic_big.ncol, ' ', unit + 3); + fprintf(ostream, "%*s%8s %6u%c %5u%c %5u%c %5u%c\n", (int)indent, "", range, + counter[0], unit[0], counter[1], unit[1], counter[2], unit[2], counter[3], unit[3]); + if (0 != statistic_xxx.ntry || 0 != statistic_xxx.njit || 0 != statistic_xxx.nsta || 0 != statistic_xxx.ncol) { + LIBXSMM_SNPRINTF(range, sizeof(range), "> %u", internal_statistic_mnk); + counter[0] = internal_print_number(statistic_xxx.ntry, ' ', unit + 0); + counter[1] = internal_print_number(statistic_xxx.njit, ' ', unit + 1); + counter[2] = internal_print_number(statistic_xxx.nsta, ' ', unit + 2); + counter[3] = internal_print_number(statistic_xxx.ncol, ' ', unit + 3); + fprintf(ostream, "%*s%8s %6u%c %5u%c %5u%c %5u%c\n", (int)indent, "", range, + counter[0], unit[0], counter[1], unit[1], counter[2], unit[2], counter[3], unit[3]); + } + printed = 1; + } + return printed; +} + + +#if !(defined(_WIN32) || defined(__CYGWIN__)) +LIBXSMM_API_INLINE unsigned int internal_statistic_ntry(int precision) +{ + return internal_statistic[precision][0/*SML*/].ntry + internal_statistic[precision][1/*MED*/].ntry + + internal_statistic[precision][2/*BIG*/].ntry + internal_statistic[precision][3/*XXX*/].ntry; +} +#endif + + +#if !defined(_WIN32) +LIBXSMM_API_INLINE void internal_register_static_code( + libxsmm_gemm_precision precision, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + libxsmm_xmmfunction xgemm, libxsmm_code_pointer* registry) +{ + const libxsmm_blasint lda = m, ldb = k, ldc = m; + /*const*/ int precondition = LIBXSMM_GEMM_NO_BYPASS_DIMS(m, n, k) && LIBXSMM_GEMM_NO_BYPASS_DIMS(lda, ldb, ldc); + if (precondition) { + const size_t size = (LIBXSMM_HASH_SIZE) - sizeof(libxsmm_descriptor_kind); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_dinit(&blob, precision, + m, n, k, lda, ldb, ldc, LIBXSMM_ALPHA, LIBXSMM_BETA, LIBXSMM_FLAGS, INTERNAL_PREFETCH); + unsigned int i = LIBXSMM_MOD2( + libxsmm_crc32(LIBXSMM_HASH_SEED, desc, LIBXSMM_MIN(sizeof(libxsmm_gemm_descriptor), size)), + LIBXSMM_CAPACITY_REGISTRY); + libxsmm_code_pointer* dst_entry = registry + i; +#if !defined(NDEBUG) + libxsmm_code_pointer code; code.xgemm = xgemm; + LIBXSMM_ASSERT(NULL != code.ptr_const && NULL != registry); + LIBXSMM_ASSERT(0 == (LIBXSMM_CODE_STATIC & code.uval)); +#endif + if (NULL != dst_entry->ptr_const) { /* collision */ + const unsigned int i0 = i; + do { /* continue to linearly search for an available slot */ + i = LIBXSMM_MOD2(i + 1, LIBXSMM_CAPACITY_REGISTRY); + if (NULL == registry[i].ptr_const) break; + } while (i != i0); +#if defined(LIBXSMM_HASH_COLLISION) /* mark entry as a collision */ + dst_entry->uval |= LIBXSMM_HASH_COLLISION; +#endif + dst_entry = registry + i; /* update destination */ + internal_update_mmstatistic(desc, 0, 1/*collision*/, 0, 0); + /* out of capacity (no registry slot available) */ + LIBXSMM_ASSERT(NULL == dst_entry->ptr_const || i == i0); + } + if (NULL == dst_entry->ptr_const) { /* registry not exhausted */ + internal_registry_keys[i].entry.kind = LIBXSMM_KERNEL_KIND_MATMUL; + LIBXSMM_ASSIGN127(&internal_registry_keys[i].entry.gemm.desc, desc); + dst_entry->xgemm = xgemm; + /* mark current entry as static code (non-JIT) */ + dst_entry->uval |= LIBXSMM_CODE_STATIC; + } + internal_update_mmstatistic(desc, 1/*try*/, 0, 0, 0); + } +} +#endif + + +LIBXSMM_API_INTERN void internal_release_scratch(void); +LIBXSMM_API_INTERN void internal_release_scratch(void) +{ + libxsmm_xrelease_scratch(NULL/*lock*/); + /* release global services */ + libxsmm_memory_finalize(); + libxsmm_hash_finalize(); + libxsmm_malloc_finalize(); +} + + +/* Caution: cannot be used multiple times in a single expression! */ +LIBXSMM_API_INTERN size_t libxsmm_format_value(char buffer[32], int buffer_size, size_t nbytes, const char scale[], const char* unit, int base) +{ + const int len = (NULL != scale ? ((int)strlen(scale)) : 0); + const int m = LIBXSMM_INTRINSICS_BITSCANBWD64(nbytes) / base, n = LIBXSMM_MIN(m, len); + int i; + buffer[0] = 0; /* clear */ + LIBXSMM_ASSERT(NULL != unit && 0 <= base); + for (i = 0; i < n; ++i) nbytes >>= base; + LIBXSMM_SNPRINTF(buffer, buffer_size, "%i %c%s", + (int)nbytes, 0 < n ? scale[n-1] : *unit, 0 < n ? unit : ""); + return nbytes; +} + + +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE void internal_dump(FILE* ostream, int urgent); +LIBXSMM_API_INTERN void internal_dump(FILE* ostream, int urgent) +{ + char *const env_dump_build = getenv("LIBXSMM_DUMP_BUILD"); + char *const env_dump_files = (NULL != getenv("LIBXSMM_DUMP_FILES") + ? getenv("LIBXSMM_DUMP_FILES") + : getenv("LIBXSMM_DUMP_FILE")); + LIBXSMM_ASSERT_MSG(INTERNAL_SINGLETON(internal_singleton_handle), "Invalid handle"); + /* determine whether this instance is unique or not */ + if (NULL != env_dump_files && '\0' != *env_dump_files && 0 == urgent) { /* dump per-node info */ + const char* filename = strtok(env_dump_files, INTERNAL_DELIMS); + char buffer[1024]; + for (; NULL != filename; filename = strtok(NULL, INTERNAL_DELIMS)) { + FILE* file = fopen(filename, "r"); + if (NULL != file) buffer[0] = '\0'; + else { /* parse keywords */ + const int seconds = atoi(filename); + if (0 == seconds) { + const char *const pid = strstr(filename, "PID"); + if (NULL != pid) { /* PID-keyword is present */ + int n = (int)(pid - filename); + n = LIBXSMM_SNPRINTF(buffer, sizeof(buffer), "%.*s%u%s", n, filename, libxsmm_get_pid(), filename + n + 3); + if (0 < n && (int)sizeof(buffer) > n) { + file = fopen(buffer, "r"); + filename = buffer; + } + } + } + else { + fprintf(stderr, "LIBXSMM INFO: PID=%u\n", libxsmm_get_pid()); + if (0 < seconds) { +#if defined(_WIN32) + Sleep((DWORD)(1000 * seconds)); +#else + LIBXSMM_EXPECT(EXIT_SUCCESS, sleep(seconds)); +#endif + } + else for(;;) LIBXSMM_SYNC_YIELD; + } + } + if (NULL != file) { + int c = fgetc(file); + fprintf(ostream, "\n\nLIBXSMM_DUMP_FILE: %s\n", filename); + /* coverity[tainted_data] */ + while (EOF != c) { + fputc(c, stdout); + c = fgetc(file); + } + fputc('\n', stdout); + fclose(file); + } + } + } + if (NULL != internal_build_state /* dump build state */ + && NULL != env_dump_build && '\0' != *env_dump_build) + { + const int dump_build = atoi(env_dump_build); + if (0 == urgent ? (0 < dump_build) : (0 > dump_build)) { + fprintf(ostream, "\n\nBUILD_DATE=%i\n", LIBXSMM_CONFIG_BUILD_DATE); + fprintf(ostream, "%s\n", internal_build_state); + } + } +} + + +LIBXSMM_API_INTERN void internal_finalize(void); +LIBXSMM_API_INTERN void internal_finalize(void) +{ + libxsmm_finalize(); + LIBXSMM_STDIO_ACQUIRE(); /* synchronize I/O */ + if (0 != libxsmm_verbosity) { /* print statistic on termination */ + const char *const env_target_hidden = getenv("LIBXSMM_TARGET_HIDDEN"); + const char *const target_arch = (NULL == env_target_hidden || 0 == atoi(env_target_hidden)) + ? libxsmm_cpuid_name(libxsmm_target_archid) : NULL/*hidden*/; + fprintf(stderr, "\nLIBXSMM_VERSION: %s%s%s (%i)", LIBXSMM_BRANCH, + 0 != *(LIBXSMM_BRANCH) ? "-" : "", 0 != *(LIBXSMM_VERSION) ? (LIBXSMM_VERSION) : "unconfigured", + LIBXSMM_VERSION4(LIBXSMM_VERSION_MAJOR, LIBXSMM_VERSION_MINOR, LIBXSMM_VERSION_UPDATE, LIBXSMM_VERSION_PATCH)); + if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) { + unsigned int linebreak = (0 == internal_print_statistic(stderr, target_arch, 1/*SP*/, 1, 0)) ? 1 : 0; + const int high_verbosity = (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity); + char number_format_buffer[32]; + libxsmm_scratch_info scratch_info; +#if defined(LIBXSMM_PLATFORM_X86) + libxsmm_cpuid_info info; + libxsmm_cpuid_x86(&info); + if ((LIBXSMM_VERBOSITY_HIGH < libxsmm_verbosity || 0 > libxsmm_verbosity) && + 0 == internal_cpuid_info.has_context && 0 != info.has_context) + { + fprintf(stderr, "\nLIBXSMM: CPU features have been promoted."); + } +#endif + if (0 == internal_print_statistic(stderr, target_arch, 0/*DP*/, linebreak, 0) && 0 != linebreak && NULL != target_arch) { + fprintf(stderr, "\nLIBXSMM_TARGET: %s\n", target_arch); + } + if (0 != libxsmm_format_value(number_format_buffer, sizeof(number_format_buffer), +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + sizeof(internal_cache_type) * (LIBXSMM_NTHREADS_MAX) + +#endif + (sizeof(internal_regkey_type) + sizeof(libxsmm_code_pointer)) * (LIBXSMM_CAPACITY_REGISTRY), + "KM", "B", 10)) + { + fprintf(stderr, "Registry and code: %s", number_format_buffer); + if (0 != libxsmm_format_value(number_format_buffer, sizeof(number_format_buffer), internal_registry_nbytes, "KM", "B", 10)) { + fprintf(stderr, " + %s", number_format_buffer); + } + if (0 != high_verbosity) { + unsigned int ngemms = 0; + int i; for (i = 0; i < 4; ++i) { + ngemms += internal_statistic[0/*DP*/][i].nsta + internal_statistic[1/*SP*/][i].nsta; + ngemms += internal_statistic[0/*DP*/][i].njit + internal_statistic[1/*SP*/][i].njit; + } + if (0 != ngemms || 0 != internal_statistic_num_gemv + || 0 != internal_statistic_num_meltw + || 0 != libxsmm_statistic_num_spmdm + || 0 != internal_statistic_num_user + || 0 != internal_registry_nleaks) + { + const char sep[] = " ", *s = ""; + fprintf(stderr, " ("); + if (0 != ngemms) { fprintf(stderr, "gemm=%u", ngemms); s = sep; } + if (0 != internal_statistic_num_gemv) { fprintf(stderr, "%sgemv=%u", s, internal_statistic_num_gemv); s = sep; } + if (0 != internal_statistic_num_meltw) { fprintf(stderr, "%smeltw=%u", s, internal_statistic_num_meltw); s = sep; } + if (0 != libxsmm_statistic_num_spmdm) { fprintf(stderr, "%sspmdm=%u", s, libxsmm_statistic_num_spmdm); s = sep; } + if (0 != internal_statistic_num_user) { fprintf(stderr, "%suser=%u", s, internal_statistic_num_user); s = sep; } + if (0 != internal_registry_nleaks) { fprintf(stderr, "%snleaks=%u", s, internal_registry_nleaks); s = sep; } + fprintf(stderr, ")"); + } + } + fprintf(stderr, "\n"); + } + if (EXIT_SUCCESS == libxsmm_get_scratch_info(&scratch_info)) { + if (0 != scratch_info.size && + 0 != libxsmm_format_value(number_format_buffer, sizeof(number_format_buffer), scratch_info.size, "KM", "B", 10)) + { + fprintf(stderr, "Scratch: %s", number_format_buffer); + if (0 != high_verbosity) { + fprintf(stderr, " (mallocs=%lu, pools=%u)\n", (unsigned long int)scratch_info.nmallocs, scratch_info.npools); + } + else { + fprintf(stderr, "\n"); + } + } + if (0 != scratch_info.internal && 0 != high_verbosity && + libxsmm_format_value(number_format_buffer, sizeof(number_format_buffer), scratch_info.internal, "KM", "B", 10)) + { + fprintf(stderr, "Private: %s\n", number_format_buffer); + } + } + if (LIBXSMM_VERBOSITY_HIGH < libxsmm_verbosity || 0 > libxsmm_verbosity) { + fprintf(stderr, "Uptime: %f s", libxsmm_timer_duration(internal_timer_start, libxsmm_timer_tick())); + if (1 < libxsmm_thread_count && INT_MAX == libxsmm_verbosity) { + fprintf(stderr, " (nthreads=%u)", libxsmm_thread_count); + } + fprintf(stderr, "\n"); + } + } + else { + fprintf(stderr, "\nLIBXSMM_TARGET: %s\n", target_arch); + } + } + /* release scratch memory pool */ + if (EXIT_SUCCESS != atexit(internal_release_scratch) && 0 != libxsmm_verbosity) { + fprintf(stderr, "LIBXSMM ERROR: failed to perform final cleanup!\n"); + } + /* determine whether this instance is unique or not */ + if (INTERNAL_SINGLETON(internal_singleton_handle)) { + internal_dump(stdout, 0/*urgent*/); + /* cleanup singleton */ +#if defined(_WIN32) + ReleaseMutex(internal_singleton_handle); + CloseHandle(internal_singleton_handle); +#else + unlink(internal_singleton_fname); + close(internal_singleton_handle); +#endif + } + LIBXSMM_STDIO_RELEASE(); /* synchronize I/O */ +#if (0 != LIBXSMM_SYNC) + { /* release locks */ +# if (1 < INTERNAL_REGLOCK_MAXN) + int i; for (i = 0; i < internal_reglock_count; ++i) LIBXSMM_LOCK_DESTROY(LIBXSMM_REGLOCK, &internal_reglock[i].state); +# elif !defined(LIBXSMM_UNIFY_LOCKS) + LIBXSMM_LOCK_DESTROY(LIBXSMM_REGLOCK, internal_reglock_ptr); +# endif + LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK, &libxsmm_lock_global); + } +#endif +} + + +#if defined(LIBXSMM_INTERCEPT_DYNAMIC) +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void _gfortran_stop_string(const char* /*message*/, int /*len*/, int /*quiet*/); +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void _gfortran_stop_string(const char* message, int len, int quiet) +{ /* STOP termination handler for GNU Fortran runtime */ + static int once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&once, 1, LIBXSMM_ATOMIC_RELAXED)) { + union { const void* dlsym; void (*ptr)(const char*, int, int); } stop; + dlerror(); /* clear an eventual error status */ + stop.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "_gfortran_stop_string"); + if (NULL != stop.dlsym) { + stop.ptr(message, len, quiet); + } + else exit(EXIT_SUCCESS); /* statically linked runtime */ + } +} + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void for_stop_core(const char* /*message*/, int /*len*/); +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void for_stop_core(const char* message, int len) +{ /* STOP termination handler for Intel Fortran runtime */ + static int once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&once, 1, LIBXSMM_ATOMIC_RELAXED)) { + union { const void* dlsym; void (*ptr)(const char*, int); } stop; + dlerror(); /* clear an eventual error status */ + stop.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "for_stop_core"); + if (NULL != stop.dlsym) { + stop.ptr(message, len); + } + else exit(EXIT_SUCCESS); /* statically linked runtime */ + } +} + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void for_stop_core_quiet(void); +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void for_stop_core_quiet(void) +{ /* STOP termination handler for Intel Fortran runtime */ + static int once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&once, 1, LIBXSMM_ATOMIC_RELAXED)) { + union { const void* dlsym; void (*ptr)(void); } stop; + dlerror(); /* clear an eventual error status */ + stop.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "for_stop_core_quiet"); + if (NULL != stop.dlsym) { + stop.ptr(); + } + else exit(EXIT_SUCCESS); /* statically linked runtime */ + } +} +#endif + + +LIBXSMM_API_INTERN size_t internal_strlen(const char* /*cstr*/, size_t /*maxlen*/); +LIBXSMM_API_INTERN size_t internal_strlen(const char* cstr, size_t maxlen) +{ + size_t result = 0; + if (NULL != cstr) { + while (0 != cstr[result] && result < maxlen) ++result; + } + return result; +} + + +LIBXSMM_API_INTERN size_t internal_parse_nbytes(const char* /*nbytes*/, size_t /*ndefault*/, int* /*valid*/); +LIBXSMM_API_INTERN size_t internal_parse_nbytes(const char* nbytes, size_t ndefault, int* valid) +{ + size_t result = ndefault; + if (NULL != nbytes && '\0' != *nbytes) { + size_t u = internal_strlen(nbytes, 32) - 1; + const char units[] = "kmgKMG", *const unit = strchr(units, nbytes[u]); + char* end = NULL; + /* take parsed value with increased type-width */ + const long long int ibytes = strtol(nbytes, &end, 10); + if (NULL != end && ( /* no obvious error */ + /* must match allowed set of units */ + (NULL != unit && *unit == *end) || + /* value is given without unit */ + (NULL == unit && '\0' == *end))) + { + result = (size_t)ibytes; + if ((size_t)LIBXSMM_UNLIMITED != result) { + u = (NULL != unit ? ((unit - units) % 3) : 3); + if (u < 3) { + result <<= (u + 1) * 10; + } + } + if (NULL != valid) *valid = 1; + } + else if (NULL != valid) *valid = 0; + } + else if (NULL != valid) { + *valid = 0; + } + return result; +} + + +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE void internal_init(void); +LIBXSMM_API_INTERN void internal_init(void) +{ + int i; +#if (0 != LIBXSMM_SYNC) /* setup the locks in a thread-safe fashion */ + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, &libxsmm_lock_global); +# if (1 < INTERNAL_REGLOCK_MAXN) + for (i = 0; i < internal_reglock_count; ++i) LIBXSMM_LOCK_ACQUIRE(LIBXSMM_REGLOCK, &internal_reglock[i].state); +# elif !defined(LIBXSMM_UNIFY_LOCKS) + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_REGLOCK, internal_reglock_ptr); +# endif +#endif + if (NULL == internal_registry) { /* double-check after acquiring the lock(s) */ +#if defined(LIBXSMM_INTERCEPT_DYNAMIC) && defined(LIBXSMM_AUTOPIN) + /* clear error status (dummy condition: it does not matter if MPI_Init or MPI_Abort) */ + const char *const dlsymname = (NULL == dlerror() ? "MPI_Init" : "MPI_Abort"); + const void *const dlsymbol = dlsym(LIBXSMM_RTLD_NEXT, dlsymname); + const void *const dlmpi = (NULL == dlerror() ? dlsymbol : NULL); +#endif + const char *const env_verbose = getenv("LIBXSMM_VERBOSE"); + void* new_registry = NULL, * new_keys = NULL; +#if defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) +# if defined(LIBXSMM_NTHREADS_USE) + void* new_cache = NULL; +# endif + const char *const env_cache = getenv("LIBXSMM_CACHE"); + if (NULL != env_cache && '\0' != *env_cache) { + const int cache_size = atoi(env_cache), cache_size2 = LIBXSMM_UP2POT(cache_size); + internal_cache_size = LIBXSMM_MIN(cache_size2, LIBXSMM_CACHE_MAXSIZE); + } + else { + internal_cache_size = LIBXSMM_CACHE_MAXSIZE; + } +#endif + /* setup verbosity as early as possible since below code may rely on verbose output */ + if (NULL != env_verbose && '\0' != *env_verbose) { + libxsmm_verbosity = atoi(env_verbose); + } +#if !defined(NDEBUG) + else { + libxsmm_verbosity = INT_MAX; /* quiet -> verbose */ + } +#endif +#if (0 == LIBXSMM_JIT) + if (2 > libxsmm_ninit && (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity)) { + fprintf(stderr, "LIBXSMM: JIT-code generation was disabled at compile-time.\n"); + } +#endif +#if defined(LIBXSMM_AUTOPIN) +# if defined(LIBXSMM_INTERCEPT_DYNAMIC) + /* MPI: unwanted affinity can slow-down unrelated jobs (over-subscription), e.g., CP2K regtests */ + if (NULL == dlmpi) +# endif + { /* setup some viable affinity if nothing else is present */ + const char *const gomp_cpu_affinity = getenv("GOMP_CPU_AFFINITY"); + const char *const kmp_affinity = getenv("KMP_AFFINITY"); + const char *const omp_proc_bind = getenv("OMP_PROC_BIND"); + if ((NULL == gomp_cpu_affinity || 0 == *gomp_cpu_affinity) + && (NULL == kmp_affinity || 0 == *kmp_affinity) + && (NULL == omp_proc_bind || 0 == *omp_proc_bind)) + { + static char affinity[] = "OMP_PROC_BIND=TRUE"; + LIBXSMM_EXPECT(EXIT_SUCCESS, LIBXSMM_PUTENV(affinity)); + if (LIBXSMM_VERBOSITY_HIGH < libxsmm_verbosity || 0 > libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM: prepared to pin threads.\n"); + } + } + } +# if defined(LIBXSMM_INTERCEPT_DYNAMIC) && 1 + else if (NULL == getenv("I_MPI_SHM_HEAP")) { + static char shmheap[] = "I_MPI_SHM_HEAP=1"; + LIBXSMM_EXPECT(EXIT_SUCCESS, LIBXSMM_PUTENV(shmheap)); + } +# endif +#endif +#if !defined(_WIN32) && 0 + umask(S_IRUSR | S_IWUSR); /* setup default/secure file mask */ +#endif +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + { const char *const env = getenv("LIBXSMM_SCRATCH_POOLS"); + if (NULL == env || 0 == *env) { + libxsmm_scratch_pools = LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS; + } + else { + libxsmm_scratch_pools = LIBXSMM_CLMP(atoi(env), 0, LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS); + /*libxsmm_scratch_pools_locked = 1;*/ + } + LIBXSMM_ASSERT(libxsmm_scratch_pools <= LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS); + } + { const char *const env = getenv("LIBXSMM_SCRATCH_SCALE"); + if (NULL == env || 0 == *env) { + libxsmm_scratch_scale = LIBXSMM_MALLOC_SCRATCH_SCALE; + } + else { + libxsmm_scratch_scale = LIBXSMM_CLMP(atof(env), 1.0, 10.0); + /*libxsmm_scratch_scale_locked = 1;*/ + } + assert(1 <= libxsmm_scratch_scale); /* !LIBXSMM_ASSERT */ + } + libxsmm_set_scratch_limit(internal_parse_nbytes(getenv("LIBXSMM_SCRATCH_LIMIT"), LIBXSMM_SCRATCH_DEFAULT, NULL/*valid*/)); +#endif /*defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS))*/ + { /* setup malloc-interception after internal allocations */ + const libxsmm_malloc_function null_malloc_fn = { 0 }; + const libxsmm_free_function null_free_fn = { 0 }; + char *const env_k = getenv("LIBXSMM_MALLOC"), *const env_t = getenv("LIBXSMM_MALLOC_LIMIT"), *end = NULL; + const char* env_i = (NULL != env_t ? strtok(env_t, INTERNAL_DELIMS) : NULL); + size_t malloc_lo = internal_parse_nbytes(env_i, LIBXSMM_MALLOC_LIMIT, NULL/*valid*/); + size_t malloc_hi = (NULL != env_i ? internal_parse_nbytes( + strtok(NULL, INTERNAL_DELIMS), LIBXSMM_SCRATCH_UNLIMITED, NULL/*valid*/) : LIBXSMM_SCRATCH_UNLIMITED); + const int malloc_kind = ((NULL == env_k || 0 == *env_k) ? 0/*disabled*/ : ((int)strtol(env_k, &end, 10))); + libxsmm_xset_default_allocator(NULL/*lock*/, NULL/*context*/, null_malloc_fn, null_free_fn); + libxsmm_xset_scratch_allocator(NULL/*lock*/, NULL/*context*/, null_malloc_fn, null_free_fn); + /* libxsmm_set_malloc implies libxsmm_malloc_init */ + if (NULL == end) { + libxsmm_set_malloc(0, &malloc_lo, &malloc_hi); + } + else if ('\0' == *end) { + libxsmm_set_malloc(malloc_kind, &malloc_lo, &malloc_hi); + } + else { + int valid = 1; + env_i = strtok(env_k, INTERNAL_DELIMS); + malloc_lo = internal_parse_nbytes(env_i, LIBXSMM_MALLOC_LIMIT, &valid); + env_i = (0 != valid ? strtok(NULL, INTERNAL_DELIMS) : NULL); + malloc_hi = (NULL != env_i + ? internal_parse_nbytes(env_i, LIBXSMM_SCRATCH_UNLIMITED, &valid) + : LIBXSMM_SCRATCH_UNLIMITED); + libxsmm_set_malloc(0 != valid ? 1 : 0, &malloc_lo, &malloc_hi); + } + } +#if defined(LIBXSMM_MAXTARGET) + libxsmm_set_target_arch(LIBXSMM_STRINGIFY(LIBXSMM_MAXTARGET)); +#else /* attempt to set libxsmm_target_archid per environment variable */ + libxsmm_set_target_arch(getenv("LIBXSMM_TARGET")); +#endif + { const char *const env = getenv("LIBXSMM_SYNC"); + libxsmm_nosync = (NULL == env || 0 == *env) ? 0/*default*/ : atoi(env); + } + /* clear internal counters/statistic */ + for (i = 0; i < 4/*sml/med/big/xxx*/; ++i) { + LIBXSMM_MEMZERO127(&internal_statistic[0/*DP*/][i]); + LIBXSMM_MEMZERO127(&internal_statistic[1/*SP*/][i]); + } + internal_statistic_mnk = LIBXSMM_MAX_DIM; + internal_statistic_sml = 13; + internal_statistic_med = 23; + LIBXSMM_ASSERT(LIBXSMM_ISPOT(LIBXSMM_CAPACITY_REGISTRY)); + libxsmm_hash_init(libxsmm_target_archid); /* used by debug memory allocation (checksum) */ + libxsmm_memory_init(libxsmm_target_archid); + if ( +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + (EXIT_SUCCESS == libxsmm_xmalloc(&new_cache, /* if internal_cache_size is zero, allocation must still happen (later control-flow too expensive) */ + sizeof(internal_cache_type) * (LIBXSMM_NTHREADS_MAX), LIBXSMM_CACHELINE/*alignment*/, + LIBXSMM_MALLOC_FLAG_PRIVATE, NULL/*extra*/, 0/*extra-size*/) && NULL != new_cache) && +#endif + (EXIT_SUCCESS == libxsmm_xmalloc(&new_keys, (LIBXSMM_CAPACITY_REGISTRY) * sizeof(internal_regkey_type), 0/*auto-align*/, + LIBXSMM_MALLOC_FLAG_PRIVATE, NULL/*extra*/, 0/*extra-size*/) && NULL != new_keys) && + (EXIT_SUCCESS == libxsmm_xmalloc(&new_registry, (LIBXSMM_CAPACITY_REGISTRY) * sizeof(libxsmm_code_pointer), 0/*auto-align*/, + LIBXSMM_MALLOC_FLAG_PRIVATE, NULL/*extra*/, 0/*extra-size*/) && NULL != new_registry)) + { +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + LIBXSMM_ASSERT(NULL != new_cache); /* SA: suppress false positive */ + memset(new_cache, 0, (LIBXSMM_NTHREADS_MAX) * sizeof(internal_cache_type)); +#endif + libxsmm_xcopy_init(libxsmm_target_archid); + libxsmm_dnn_init(libxsmm_target_archid); + { const char *const env = getenv("LIBXSMM_GEMM_PREFETCH"); +#if (defined(_WIN32) || defined(__CYGWIN__)) + libxsmm_gemm_auto_prefetch_default = INTERNAL_PREFETCH; +#else + libxsmm_gemm_auto_prefetch_default = (0 == internal_statistic_ntry(0/*DP*/) && 0 == internal_statistic_ntry(1/*SP*/)) + /* avoid special prefetch if static code is present, since such code uses INTERNAL_PREFETCH */ + ? (((LIBXSMM_X86_AVX512 >= libxsmm_target_archid || LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid)) + ? LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C : LIBXSMM_GEMM_PREFETCH_BL2_VIA_C) + : INTERNAL_PREFETCH; +#endif + libxsmm_gemm_auto_prefetch = INTERNAL_PREFETCH; + if (NULL != env && '\0' != *env) { /* user input beyond auto-prefetch is always considered */ + const int uid = atoi(env); + if (0 <= uid) { + libxsmm_gemm_auto_prefetch_default = libxsmm_gemm_uid2prefetch(uid); + libxsmm_gemm_auto_prefetch = libxsmm_gemm_auto_prefetch_default; + internal_gemm_auto_prefetch_locked = 1; + } + } + } + for (i = 0; i < (LIBXSMM_CAPACITY_REGISTRY); ++i) ((libxsmm_code_pointer*)new_registry)[i].ptr = NULL; + LIBXSMM_ASSERT(NULL == internal_registry && NULL == internal_registry_keys); +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + LIBXSMM_ASSERT(NULL == internal_cache_buffer); + internal_cache_buffer = (internal_cache_type*)new_cache; +#endif + internal_registry_keys = (internal_regkey_type*)new_keys; /* prior to registering static kernels */ +#if defined(LIBXSMM_BUILD) && !defined(LIBXSMM_DEFAULT_CONFIG) +# include +#endif + libxsmm_gemm_init(libxsmm_target_archid); +#if defined(LIBXSMM_TRACE) + { int filter_threadid = 0/*only main-thread*/, filter_mindepth = 0, filter_maxnsyms = 0; + const int init_code = libxsmm_trace_init(filter_threadid, filter_mindepth, filter_maxnsyms); + if (EXIT_SUCCESS != init_code && 0 != libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM ERROR: failed to initialize TRACE (error #%i)!\n", init_code); + } + } +#endif + { /* commit the registry buffer and enable global visibility */ + void *const pv_registry = &internal_registry; + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, LIBXSMM_BITS)((void**)pv_registry, (void*)new_registry, LIBXSMM_ATOMIC_SEQ_CST); + } + } + else { + if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM ERROR: failed to allocate internal buffers!\n"); + } + libxsmm_xfree(new_registry, 0/*no check*/); + libxsmm_xfree(new_keys, 0/*no check*/); +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + libxsmm_xfree(new_cache, 0/*no check*/); +#endif + } + } +#if (0 != LIBXSMM_SYNC) /* release locks */ +# if (1 < INTERNAL_REGLOCK_MAXN) + for (i = 0; i < internal_reglock_count; ++i) LIBXSMM_LOCK_RELEASE(LIBXSMM_REGLOCK, &internal_reglock[i].state); +# elif !defined(LIBXSMM_UNIFY_LOCKS) + LIBXSMM_LOCK_RELEASE(LIBXSMM_REGLOCK, internal_reglock_ptr); +# endif + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, &libxsmm_lock_global); +#endif +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_CTOR void libxsmm_init(void) +{ + if (0 == LIBXSMM_ATOMIC_LOAD(&internal_registry, LIBXSMM_ATOMIC_RELAXED)) { + static unsigned int ninit = 0, gid = 0; + const unsigned int tid = LIBXSMM_ATOMIC_ADD_FETCH(&ninit, 1, LIBXSMM_ATOMIC_SEQ_CST); + LIBXSMM_ASSERT(0 < tid); + /* libxsmm_ninit (1: initialization started, 2: library initialized, higher: to invalidate code-TLS) */ + if (1 == tid) { + libxsmm_timer_tickint s0 = libxsmm_timer_tick_rtc(); /* warm-up */ + libxsmm_timer_tickint t0 = libxsmm_timer_tick_tsc(); /* warm-up */ + s0 = libxsmm_timer_tick_rtc(); t0 = libxsmm_timer_tick_tsc(); /* start timing */ + assert(0 == LIBXSMM_ATOMIC_LOAD(&libxsmm_ninit, LIBXSMM_ATOMIC_SEQ_CST)); /* !LIBXSMM_ASSERT */ + /* coverity[check_return] */ + LIBXSMM_ATOMIC_ADD_FETCH(&libxsmm_ninit, 1, LIBXSMM_ATOMIC_SEQ_CST); + gid = tid; /* protect initialization */ +#if (0 != LIBXSMM_SYNC) + /* coverity[check_return] */ + LIBXSMM_TLS_CREATE(&libxsmm_tlskey); + { /* construct and initialize locks */ +# if defined(LIBXSMM_REGLOCK_TRY) + const char *const env_trylock = getenv("LIBXSMM_TRYLOCK"); +# endif + LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK) attr_global; +# if (1 < INTERNAL_REGLOCK_MAXN) + int i; + LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_REGLOCK) attr; + LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_REGLOCK, &attr); +# elif defined(LIBXSMM_UNIFY_LOCKS) + internal_reglock_ptr = &libxsmm_lock_global; +# else + static LIBXSMM_LOCK_TYPE(LIBXSMM_REGLOCK) internal_reglock; + internal_reglock_ptr = &internal_reglock; + LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_REGLOCK) attr; + LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_REGLOCK, &attr); + LIBXSMM_LOCK_INIT(LIBXSMM_REGLOCK, internal_reglock_ptr, &attr); + LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_REGLOCK, &attr); +# endif + LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK, &attr_global); + LIBXSMM_LOCK_INIT(LIBXSMM_LOCK, &libxsmm_lock_global, &attr_global); + LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK, &attr_global); + /* control number of locks needed; LIBXSMM_TRYLOCK implies only 1 lock */ +# if defined(LIBXSMM_REGLOCK_TRY) + if (NULL == env_trylock || 0 == *env_trylock) +# endif + { /* no LIBXSMM_TRYLOCK */ +# if defined(LIBXSMM_VTUNE) + internal_reglock_count = 1; /* avoid duplicated kernels */ +# elif (1 < INTERNAL_REGLOCK_MAXN) + const char *const env_nlocks = getenv("LIBXSMM_NLOCKS"); + const int reglock_count = (NULL == env_nlocks || 0 == *env_nlocks || 1 > atoi(env_nlocks)) + ? (INTERNAL_REGLOCK_MAXN) : LIBXSMM_MIN(atoi(env_nlocks), INTERNAL_REGLOCK_MAXN); + internal_reglock_count = LIBXSMM_LO2POT(reglock_count); +# else + internal_reglock_count = 0; +# endif + } +# if defined(LIBXSMM_REGLOCK_TRY) + else { /* LIBXSMM_TRYLOCK environment variable specified */ + internal_reglock_count = (0 != atoi(env_trylock) ? 1 +# if (1 < INTERNAL_REGLOCK_MAXN) + : INTERNAL_REGLOCK_MAXN); +# else + : 0); +# endif + } +# endif +# if (1 < INTERNAL_REGLOCK_MAXN) + LIBXSMM_ASSERT(1 <= internal_reglock_count); + for (i = 0; i < internal_reglock_count; ++i) LIBXSMM_LOCK_INIT(LIBXSMM_REGLOCK, &internal_reglock[i].state, &attr); + LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_REGLOCK, &attr); +# endif + } +#endif + { /* determine whether this instance is unique or not */ +#if defined(_WIN32) + internal_singleton_handle = CreateMutex(NULL, TRUE, "GlobalLIBXSMM"); +#else + const int result = LIBXSMM_SNPRINTF(internal_singleton_fname, sizeof(internal_singleton_fname), "/tmp/.libxsmm.%u", + /*rely on user id to avoid permission issues in case of left-over files*/(unsigned int)getuid()); + struct flock singleton_flock; + int singleton_handle; + singleton_flock.l_start = 0; + singleton_flock.l_len = 0; /* entire file */ + singleton_flock.l_type = F_WRLCK; /* exclusive across PIDs */ + singleton_flock.l_whence = SEEK_SET; + singleton_handle = ((0 < result && (int)sizeof(internal_singleton_fname) > result) ? open( + internal_singleton_fname, O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR) : -1); + internal_singleton_handle = fcntl(singleton_handle, F_SETLK, &singleton_flock); + if (0 > internal_singleton_handle && 0 <= singleton_handle) close(singleton_handle); +#endif /* coverity[leaked_handle] */ + } + { /* calibrate timer */ + int register_termination_proc; + libxsmm_timer_tickint s1, t1; + internal_init(); /* must be first to initialize verbosity, etc. */ + if (INTERNAL_SINGLETON(internal_singleton_handle)) { /* after internal_init */ + internal_dump(stdout, 1/*urgent*/); + } + s1 = libxsmm_timer_tick_rtc(); t1 = libxsmm_timer_tick_tsc(); /* mid-timing */ +#if defined(LIBXSMM_PLATFORM_X86) + libxsmm_cpuid_x86(&internal_cpuid_info); + if (0 != internal_cpuid_info.constant_tsc && t0 < t1) { + libxsmm_timer_scale = libxsmm_timer_duration_rtc(s0, s1) / (t1 - t0); + } +#endif + register_termination_proc = atexit(internal_finalize); + s1 = libxsmm_timer_tick_rtc(); t1 = libxsmm_timer_tick_tsc(); /* final timing */ + /* set timer-scale and determine start of the "uptime" (shown at termination) */ + if (t0 < t1 && 0.0 < libxsmm_timer_scale) { + const double scale = libxsmm_timer_duration_rtc(s0, s1) / (t1 - t0); + const double diff = LIBXSMM_DELTA(libxsmm_timer_scale, scale) / scale; + if (5E-4 > diff) { + libxsmm_timer_scale = scale; + internal_timer_start = t0; + } + else { + libxsmm_timer_scale = 0; + internal_timer_start = s0; +#if defined(_DEBUG) + libxsmm_se = 1; +#endif + } + } + else { + internal_timer_start = s0; + libxsmm_timer_scale = 0; + } + if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + if (EXIT_SUCCESS != register_termination_proc) { + fprintf(stderr, "LIBXSMM ERROR: failed to register termination procedure!\n"); + } + if (0 == libxsmm_timer_scale +#if defined(LIBXSMM_PLATFORM_X86) + && 0 == internal_cpuid_info.constant_tsc +#endif + && (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity)) + { + /* ARM: TSC is currently not implemented, hence warning shows up (if verbose) */ + fprintf(stderr, "LIBXSMM WARNING: timer is maybe not cycle-accurate!\n"); + } + } + } + assert(1 == LIBXSMM_ATOMIC_LOAD(&libxsmm_ninit, LIBXSMM_ATOMIC_SEQ_CST)); /* !LIBXSMM_ASSERT */ + /* coverity[check_return] */ + LIBXSMM_ATOMIC_ADD_FETCH(&libxsmm_ninit, 1, LIBXSMM_ATOMIC_SEQ_CST); + } + else /*if (gid != tid)*/ { /* avoid recursion */ + LIBXSMM_ASSERT(gid != tid); + LIBXSMM_UNUSED(gid); + while (2 > LIBXSMM_ATOMIC_LOAD(&libxsmm_ninit, LIBXSMM_ATOMIC_RELAXED)) LIBXSMM_SYNC_YIELD; + internal_init(); + } +#if defined(LIBXSMM_PERF) + libxsmm_perf_init(); +#endif + } + LIBXSMM_ASSERT(1 < libxsmm_ninit); +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE void libxsmm_finalize(void); +LIBXSMM_API LIBXSMM_ATTRIBUTE_DTOR void libxsmm_finalize(void) +{ + void *const regaddr = &internal_registry; + uintptr_t regptr = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_LOAD, LIBXSMM_BITS)((uintptr_t*)regaddr, LIBXSMM_ATOMIC_RELAXED); + libxsmm_code_pointer* registry = (libxsmm_code_pointer*)regptr; + if (NULL != registry) { + int i; +#if (0 != LIBXSMM_SYNC) + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, &libxsmm_lock_global); +# if (1 < INTERNAL_REGLOCK_MAXN) + { /* acquire locks and thereby shortcut lazy initialization later on */ + int ntry = 0, n; + do { + for (i = 0, n = 0; i < internal_reglock_count; ++i) { + if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_REGLOCK) == LIBXSMM_LOCK_TRYLOCK(LIBXSMM_REGLOCK, &internal_reglock[i].state)) ++n; + } + ntry += (0 == n ? 1 : 0); + } while (n < internal_reglock_count && ntry < LIBXSMM_CLEANUP_NTRY); + } +# elif !defined(LIBXSMM_UNIFY_LOCKS) + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_REGLOCK, internal_reglock_ptr); +# endif +#endif + regptr = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_LOAD, LIBXSMM_BITS)((uintptr_t*)regaddr, LIBXSMM_ATOMIC_RELAXED); + registry = (libxsmm_code_pointer*)regptr; + if (NULL != registry) { + internal_regkey_type *const registry_keys = internal_registry_keys; +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + internal_cache_type *const cache_buffer = internal_cache_buffer; +#endif + unsigned int rest = 0, errors = 0; +#if defined(LIBXSMM_TRACE) + i = libxsmm_trace_finalize(); + if (EXIT_SUCCESS != i && 0 != libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM ERROR: failed to finalize trace (error #%i)!\n", i); + } +#endif +#if defined(LIBXSMM_PERF) + libxsmm_perf_finalize(); +#endif + libxsmm_xcopy_finalize(); + libxsmm_gemm_finalize(); + libxsmm_dnn_finalize(); + /* coverity[check_return] */ + LIBXSMM_ATOMIC_ADD_FETCH(&libxsmm_ninit, 1, LIBXSMM_ATOMIC_RELAXED); /* invalidate code cache (TLS) */ +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + internal_cache_buffer = NULL; +#endif + internal_registry_keys = NULL; /* make registry keys unavailable */ + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE_ZERO, LIBXSMM_BITS)((uintptr_t*)regaddr, LIBXSMM_ATOMIC_SEQ_CST); + internal_registry_nbytes = 0; internal_registry_nleaks = 0; + for (i = 0; i < (LIBXSMM_CAPACITY_REGISTRY); ++i) { + /*const*/ libxsmm_code_pointer code = registry[i]; + if (NULL != code.ptr_const) { + /* check if the registered entity is a GEMM kernel */ + switch (LIBXSMM_DESCRIPTOR_KIND(registry_keys[i].entry.kind)) { + case LIBXSMM_KERNEL_KIND_MATMUL: { + const libxsmm_gemm_descriptor *const desc = ®istry_keys[i].entry.gemm.desc; + if (1 < desc->m && 1 < desc->n) { + const unsigned int njit = (0 == (LIBXSMM_CODE_STATIC & code.uval) ? 1 : 0); + const unsigned int nsta = (0 != (LIBXSMM_CODE_STATIC & code.uval) ? 1 : 0); + /* count whether kernel is static or JIT-code */ + internal_update_mmstatistic(desc, 0, 0, njit, nsta); + } + else { + ++internal_statistic_num_gemv; + } + ++rest; + } break; + case LIBXSMM_KERNEL_KIND_MELTW: { + ++internal_statistic_num_meltw; + } break; + case LIBXSMM_KERNEL_KIND_USER: { + ++internal_statistic_num_user; + } break; + default: if (LIBXSMM_KERNEL_UNREGISTERED <= LIBXSMM_DESCRIPTOR_KIND(registry_keys[i].entry.kind)) { + ++errors; + } + else { + ++rest; + } + } + if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + if (0 != errors) { + fprintf(stderr, "LIBXSMM ERROR: code registry is corrupted!\n"); + } + if (LIBXSMM_CAPACITY_REGISTRY == (rest + errors + internal_statistic_num_gemv + + internal_statistic_num_user + internal_statistic_num_meltw)) + { + fprintf(stderr, "LIBXSMM WARNING: code registry was exhausted!\n"); + } + } + if (0 == (LIBXSMM_CODE_STATIC & code.uval)) { /* check for allocated/generated JIT-code */ +# if defined(__APPLE__) && defined(__arm64__) +# else + void* buffer = NULL; + size_t size = 0; +# endif +#if defined(LIBXSMM_HASH_COLLISION) + code.uval &= ~LIBXSMM_HASH_COLLISION; /* clear collision flag */ +#endif +# if defined(__APPLE__) && defined(__arm64__) + ++internal_registry_nleaks; +#else + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(code.ptr_const, &size, NULL/*flags*/, &buffer)) { + if (LIBXSMM_KERNEL_KIND_USER == LIBXSMM_DESCRIPTOR_KIND(registry_keys[i].entry.kind) + /* dump user-data just like JIT'ted code */ + && 0 > libxsmm_verbosity) + { + char name[16]; + int nchar; +#if defined(LIBXSMM_REGUSER_HASH) + const size_t descsize = LIBXSMM_DESCRIPTOR_ISBIG(registry_keys[i].entry.kind) + ? LIBXSMM_DESCRIPTOR_MAXSIZE : LIBXSMM_DESCRIPTOR_SIGSIZE; + const unsigned int id = libxsmm_crc32(LIBXSMM_HASH_SEED, registry_keys[i].entry.user.desc, + descsize - sizeof(libxsmm_descriptor_kind)); + LIBXSMM_ASSERT(descsize > sizeof(libxsmm_descriptor_kind)); +#else + const unsigned int id = internal_statistic_num_user; +#endif + nchar = LIBXSMM_SNPRINTF(name, sizeof(name), "%010u.user", id); + if (0 < nchar && (int)sizeof(name) > nchar) { + LIBXSMM_EXPECT(EXIT_SUCCESS, libxsmm_dump("LIBXSMM-USER-DUMP", name, code.ptr_const, size, 0/*unique*/)); + } + } +#if !defined(NDEBUG) + registry[i].ptr = NULL; +#endif + libxsmm_xfree(code.ptr_const, 0/*no check*/); + /* round-up size (it is fine to assume 4 KB pages since it is likely more accurate than not rounding up) */ + internal_registry_nbytes += LIBXSMM_UP2(size + (((char*)code.ptr_const) - (char*)buffer), LIBXSMM_PAGE_MINSIZE); + } + else ++internal_registry_nleaks; +#endif + } + } + } + /* release buffers (registry, keys, cache) */ +#if defined(LIBXSMM_NTHREADS_USE) && defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + libxsmm_xfree(cache_buffer, 0/*no check*/); +#endif + libxsmm_xfree(registry_keys, 0/*no check*/); + libxsmm_xfree(registry, 0/*no check*/); + } +#if (0 != LIBXSMM_SYNC) /* LIBXSMM_LOCK_RELEASE, but no LIBXSMM_LOCK_DESTROY */ +# if (1 < INTERNAL_REGLOCK_MAXN) + for (i = 0; i < internal_reglock_count; ++i) LIBXSMM_LOCK_RELEASE(LIBXSMM_REGLOCK, &internal_reglock[i].state); +# elif !defined(LIBXSMM_UNIFY_LOCKS) + LIBXSMM_LOCK_RELEASE(LIBXSMM_REGLOCK, internal_reglock_ptr); +# endif + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, &libxsmm_lock_global); + /* coverity[check_return] */ + LIBXSMM_TLS_DESTROY(libxsmm_tlskey); +#endif + } +} + + +LIBXSMM_API void libxsmm_sink(LIBXSMM_VARIADIC) +{ + /* does nothing else but sinking given arguments */ +} + + +LIBXSMM_API int libxsmm_get_target_archid(void) +{ + LIBXSMM_INIT +#if !defined(__MIC__) + return libxsmm_target_archid; +#else /* no JIT support */ + return LIBXSMM_MIN(libxsmm_target_archid, LIBXSMM_X86_GENERIC); +#endif +} + + +LIBXSMM_API void libxsmm_set_target_archid(int id) +{ + int target_archid = LIBXSMM_TARGET_ARCH_UNKNOWN; + switch (id) { + case LIBXSMM_X86_AVX512_SPR: + case LIBXSMM_X86_AVX512_CPX: + case LIBXSMM_X86_AVX512_CLX: + case LIBXSMM_X86_AVX512_CORE: + case LIBXSMM_X86_AVX512_KNM: + case LIBXSMM_X86_AVX512_MIC: + case LIBXSMM_X86_AVX512: + case LIBXSMM_X86_AVX2: + case LIBXSMM_X86_AVX: + case LIBXSMM_X86_SSE42: + case LIBXSMM_X86_SSE3: + case LIBXSMM_AARCH64_V81: + case LIBXSMM_AARCH64_V82: + case LIBXSMM_AARCH64_A64FX: { + target_archid = id; + } break; + case LIBXSMM_TARGET_ARCH_GENERIC: +#if defined(LIBXSMM_PLATFORM_X86) + target_archid = LIBXSMM_X86_GENERIC; + break; +#elif defined(LIBXSMM_PLATFORM_AARCH64) + target_archid = LIBXSMM_AARCH64_V81; + break; +#endif + default: target_archid = libxsmm_cpuid(); + } + LIBXSMM_ATOMIC_STORE(&libxsmm_target_archid, target_archid, LIBXSMM_ATOMIC_RELAXED); + if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + const int cpuid = libxsmm_cpuid(); + if (cpuid < target_archid) { + const char *const target_arch = libxsmm_cpuid_name(target_archid); + fprintf(stderr, "LIBXSMM WARNING: \"%s\" code may fail to run on \"%s\"!\n", + target_arch, libxsmm_cpuid_name(cpuid)); + } + } +} + + +LIBXSMM_API const char* libxsmm_get_target_arch(void) +{ + LIBXSMM_INIT + return libxsmm_cpuid_name(libxsmm_target_archid); +} + + +/* function serves as a helper for implementing the Fortran interface */ +LIBXSMM_API const char* libxsmmf_get_target_arch(int* length); +LIBXSMM_API const char* libxsmmf_get_target_arch(int* length) +{ + const char *const arch = libxsmm_get_target_arch(); + /* valid here since function is not in the public interface */ + LIBXSMM_ASSERT(NULL != arch && 0 != length); + *length = (int)strlen(arch); + return arch; +} + + +LIBXSMM_API void libxsmm_set_target_arch(const char* arch) +{ + const int cpuid = libxsmm_cpuid(); + int target_archid; + if (NULL != arch && '\0' != *arch) { +#if defined(LIBXSMM_PLATFORM_X86) + const int jit = atoi(arch); +#endif + if (0 == strcmp("0", arch)) { +#if defined(LIBXSMM_PLATFORM_X86) + target_archid = LIBXSMM_X86_GENERIC; +#elif defined(LIBXSMM_PLATFORM_AARCH64) + target_archid = LIBXSMM_AARCH64_V81; +#else + target_archid = LIBXSMM_TARGET_ARCH_GENERIC; +#endif + } +#if defined(LIBXSMM_PLATFORM_X86) + else if (0 < jit) { + target_archid = LIBXSMM_X86_GENERIC + jit; + } + else if (arch == libxsmm_stristr(arch, "spr") || arch == libxsmm_stristr(arch, "amx")) { + target_archid = LIBXSMM_X86_AVX512_SPR; + } + else if (arch == libxsmm_stristr(arch, "cpx")) { + target_archid = LIBXSMM_X86_AVX512_CPX; + } + else if (arch == libxsmm_stristr(arch, "clx")) { + target_archid = LIBXSMM_X86_AVX512_CLX; + } + else if (arch == libxsmm_stristr(arch, "skx") || arch == libxsmm_stristr(arch, "skl") + /* "avx3"/"avx512" previously enabled LIBXSMM_X86_AVX512 */ + || arch == libxsmm_stristr(arch, "avx3") || arch == libxsmm_stristr(arch, "avx512")) + { + target_archid = LIBXSMM_X86_AVX512_CORE; + } + else if (arch == libxsmm_stristr(arch, "knm")) { + target_archid = LIBXSMM_X86_AVX512_KNM; + } + else if (arch == libxsmm_stristr(arch, "knl") || arch == libxsmm_stristr(arch, "mic")) { + target_archid = LIBXSMM_X86_AVX512_MIC; + } + else if (arch == libxsmm_stristr(arch, "hsw") || arch == libxsmm_stristr(arch, "avx2")) { + target_archid = LIBXSMM_X86_AVX2; + } + else if (arch == libxsmm_stristr(arch, "snb") || arch == libxsmm_stristr(arch, "avx")) { + target_archid = LIBXSMM_X86_AVX; + } + else if (arch == libxsmm_stristr(arch, "wsm") || arch == libxsmm_stristr(arch, "nhm") + || arch == libxsmm_stristr(arch, "sse4_2") || arch == libxsmm_stristr(arch, "sse4.2") + || arch == libxsmm_stristr(arch, "sse42") || arch == libxsmm_stristr(arch, "sse4")) + { + target_archid = LIBXSMM_X86_SSE42; + } + else if (arch == libxsmm_stristr(arch, "sse3")) { + target_archid = LIBXSMM_X86_SSE3; + } + else if (arch == libxsmm_stristr(arch, "x86") || arch == libxsmm_stristr(arch, "x86_64") + || arch == libxsmm_stristr(arch, "x64") || arch == libxsmm_stristr(arch, "sse2")) + { + target_archid = LIBXSMM_X86_GENERIC; + } +#elif defined(LIBXSMM_PLATFORM_AARCH64) + else if (arch == libxsmm_stristr(arch, "arm") || arch == libxsmm_stristr(arch, "arm64") + || arch == libxsmm_stristr(arch, "arm_v81") + || arch == libxsmm_stristr(arch, "aarch64")) + { + target_archid = LIBXSMM_AARCH64_V81; + } + else if (arch == libxsmm_stristr(arch, "arm_v82")) { + target_archid = LIBXSMM_AARCH64_V82; + } + else if (arch == libxsmm_stristr(arch, "a64fx")) + { + target_archid = LIBXSMM_AARCH64_A64FX; + } +#endif + else if (arch == libxsmm_stristr(arch, "generic")) { +#if defined(LIBXSMM_PLATFORM_X86) + target_archid = LIBXSMM_X86_GENERIC; +#elif defined(LIBXSMM_PLATFORM_AARCH64) + target_archid = LIBXSMM_AARCH64_V81; +#else + target_archid = LIBXSMM_TARGET_ARCH_GENERIC; +#endif + } + else if (arch == libxsmm_stristr(arch, "none")) { + target_archid = LIBXSMM_TARGET_ARCH_GENERIC; + } + else { + target_archid = cpuid; + } + } + else { + target_archid = cpuid; + } + if (cpuid < target_archid) { /* warn about code path if beyond CPUID */ + static int error_once = 0; + if ( 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + const char *const target_arch = libxsmm_cpuid_name(target_archid); + fprintf(stderr, "LIBXSMM WARNING: \"%s\" code will fail to run on \"%s\"!\n", + target_arch, libxsmm_cpuid_name(cpuid)); + } +#if 0 /* limit code path to confirmed features */ + target_archid = cpuid; +#endif + } + LIBXSMM_ATOMIC_STORE(&libxsmm_target_archid, target_archid, LIBXSMM_ATOMIC_RELAXED); +} + + +LIBXSMM_API int libxsmm_get_verbosity(void) +{ + LIBXSMM_INIT + return libxsmm_verbosity; +} + + +LIBXSMM_API void libxsmm_set_verbosity(int level) +{ + LIBXSMM_INIT + LIBXSMM_ATOMIC_STORE(&libxsmm_verbosity, level, LIBXSMM_ATOMIC_RELAXED); +} + + +LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_auto_prefetch(void) +{ + return (libxsmm_gemm_prefetch_type)libxsmm_gemm_auto_prefetch; +} + + +LIBXSMM_API void libxsmm_set_gemm_auto_prefetch(libxsmm_gemm_prefetch_type strategy) +{ + if (0 == internal_gemm_auto_prefetch_locked) { /* LIBXSMM_GEMM_PREFETCH environment takes precedence */ + LIBXSMM_ATOMIC_STORE(&libxsmm_gemm_auto_prefetch_default, strategy, LIBXSMM_ATOMIC_RELAXED); + LIBXSMM_ATOMIC_STORE(&libxsmm_gemm_auto_prefetch, strategy, LIBXSMM_ATOMIC_RELAXED); + } +} + + +LIBXSMM_API unsigned char libxsmm_typesize(libxsmm_datatype datatype) +{ + const unsigned char result = (unsigned char)LIBXSMM_TYPESIZE(datatype); + if (0 != result) { + return result; + } + else { + static int error_once = 0; + LIBXSMM_ASSERT_MSG(0, "unsupported data type"); + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM ERROR: unsupported data type!\n"); + } + return 1; /* avoid to return 0 to avoid div-by-zero in static analysis of depending code */ + } +} + + +LIBXSMM_API int libxsmm_dvalue(libxsmm_datatype datatype, const void* value, double* dvalue) +{ + int result = EXIT_SUCCESS; + if (NULL != value && NULL != dvalue) { + switch (datatype) { + case LIBXSMM_DATATYPE_F64: *dvalue = (*(const double *)value); break; + case LIBXSMM_DATATYPE_F32: *dvalue = (double)(*(const float *)value); break; + case LIBXSMM_DATATYPE_I64: *dvalue = (double)(*(const long long*)value); break; + case LIBXSMM_DATATYPE_I32: *dvalue = (double)(*(const int *)value); break; + case LIBXSMM_DATATYPE_I16: *dvalue = (double)(*(const short *)value); break; + case LIBXSMM_DATATYPE_I8: *dvalue = (double)(*(const char *)value); break; + default: result = EXIT_FAILURE; + } + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API_INTERN const char* libxsmm_typename(libxsmm_datatype datatype) +{ + switch (datatype) { + case LIBXSMM_DATATYPE_F64: return "f64"; + case LIBXSMM_DATATYPE_F32: return "f32"; + case LIBXSMM_DATATYPE_BF16: return "bf16"; + case LIBXSMM_DATATYPE_F16: return "f16"; + case LIBXSMM_DATATYPE_I64: return "i64"; + case LIBXSMM_DATATYPE_I32: return "i32"; + case LIBXSMM_DATATYPE_I16: return "i16"; + case LIBXSMM_DATATYPE_I8: return "i8"; + default: { + if (LIBXSMM_GEMM_PRECISION_I16 == LIBXSMM_GETENUM_INP(datatype) && + LIBXSMM_GEMM_PRECISION_I32 == LIBXSMM_GETENUM_OUT(datatype)) + { + return "i16i32"; + } + else if (LIBXSMM_GEMM_PRECISION_I16 == LIBXSMM_GETENUM_INP(datatype) && + LIBXSMM_GEMM_PRECISION_F32 == LIBXSMM_GETENUM_OUT(datatype)) + { + return "i16f32"; + } + else if (LIBXSMM_GEMM_PRECISION_I8 == LIBXSMM_GETENUM_INP(datatype) && + LIBXSMM_GEMM_PRECISION_I32 == LIBXSMM_GETENUM_OUT(datatype)) + { + return "i8i32"; + } + else if (LIBXSMM_GEMM_PRECISION_BF16 == LIBXSMM_GETENUM_INP(datatype) && + LIBXSMM_GEMM_PRECISION_F32 == LIBXSMM_GETENUM_OUT(datatype)) + { + return "bf16f32"; + } + else { + return "void"; + } + } + } +} + + +LIBXSMM_API_INLINE void internal_get_typesize_string(char buffer[4], int buffer_size, size_t typesize) +{ + LIBXSMM_ASSERT(256 > typesize && 4 <= buffer_size); + if (10 > typesize) { + buffer[0] = (char)('0' + typesize); + buffer[1] = 0; + } + else { + LIBXSMM_SNPRINTF(buffer, buffer_size, "%i", (int)typesize); + } +} + + +LIBXSMM_API_INTERN int libxsmm_dump(const char* title, const char* name, const void* data, size_t size, int unique) +{ + int result; + if (NULL != name && '\0' != *name && NULL != data && 0 != size) { + FILE* data_file = fopen(name, "rb"); + int diff = 0, result_close; + if (NULL == data_file) { /* file does not exist */ + data_file = fopen(name, "wb"); + if (NULL != data_file) { /* dump data into a file */ + result = ((size == fwrite(data, 1, size, data_file)) ? EXIT_SUCCESS : EXIT_FAILURE); + result_close = fclose(data_file); + if (EXIT_SUCCESS == result) result = result_close; + } + else result = EXIT_FAILURE; + } + else if (0 != unique) { /* check existing file */ + const char* check_a = (const char*)data; + char check_b[4096]; + size_t rest = size; + do { + const size_t n = fread(check_b, 1, LIBXSMM_MIN(sizeof(check_b), rest), data_file); + diff += memcmp(check_a, check_b, LIBXSMM_MIN(sizeof(check_b), n)); + check_a += n; + rest -= n; + } while (0 < rest && 0 == diff); + result = fclose(data_file); + } + else { + result = fclose(data_file); + } + if (EXIT_SUCCESS == result && NULL != title && '\0' != *title) { + fprintf(stderr, "%s(ptr:file) %p : %s\n", title, data, name); + } + if (0 != diff) { /* override existing dump and warn about erroneous condition */ + fprintf(stderr, "LIBXSMM ERROR: %s is not a unique filename!\n", name); + data_file = fopen(name, "wb"); + if (NULL != data_file) { /* dump data into a file */ + if (size != fwrite(data, 1, size, data_file)) result = EXIT_FAILURE; + result_close = fclose(data_file); + if (EXIT_SUCCESS == result) result = result_close; + } + if (EXIT_SUCCESS == result) result = EXIT_FAILURE; + } + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API_INTERN int libxsmm_build(const libxsmm_build_request* request, unsigned int regindex, libxsmm_code_pointer* code) +{ + int result = EXIT_SUCCESS; +#if !defined(__MIC__) + const char * /*const*/ target_arch = libxsmm_cpuid_name(libxsmm_target_archid); + /* large enough temporary buffer for generated code */ + char jit_buffer[LIBXSMM_CODE_MAXSIZE], jit_name[256] = { 0 }; + libxsmm_generated_code generated_code; + libxsmm_kernel_xinfo extra; + + LIBXSMM_MEMZERO127(&generated_code); + generated_code.generated_code = jit_buffer; + generated_code.buffer_size = sizeof(jit_buffer); + /* setup code generation */ + generated_code.arch = libxsmm_target_archid; + generated_code.code_type = 2; + +# if !defined(NDEBUG) /* should not be needed (all members will be initialized below) */ + LIBXSMM_MEMZERO127(&extra); +# endif + extra.registered = regindex; + extra.nflops = 0; + + LIBXSMM_ASSERT(NULL != generated_code.generated_code || 0 == generated_code.buffer_size); + LIBXSMM_ASSERT(NULL != request && 0 != libxsmm_target_archid); + LIBXSMM_ASSERT(NULL != code && NULL == code->ptr_const); + LIBXSMM_ASSERT(0 == LIBXSMM_DESCRIPTOR_ISBIG(request->kind)); + + switch (request->kind) { /* generate kernel */ + case LIBXSMM_BUILD_KIND_GEMM: { /* small MxM kernel */ + LIBXSMM_ASSERT(NULL != request->descriptor.gemm); +# if 0 /* dummy kernel for an empty shape is desired */ + if (0 < request->descriptor.gemm->m && 0 < request->descriptor.gemm->n && 0 < request->descriptor.gemm->k && + 0 < request->descriptor.gemm->lda && 0 < request->descriptor.gemm->ldb && 0 < request->descriptor.gemm->ldc) +# endif + { + const unsigned int m = request->descriptor.gemm->m, n = request->descriptor.gemm->n, k = request->descriptor.gemm->k; + extra.nflops = 2 * m * n * k; +# if !defined(LIBXSMM_DENY_RETARGET) /* disable: ECFLAGS=-DLIBXSMM_DENY_RETARGET */ + if ((LIBXSMM_X86_AVX2 < libxsmm_target_archid) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) && + (LIBXSMM_GEMM_PRECISION_F64 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.gemm->datatype) || + LIBXSMM_GEMM_PRECISION_F32 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.gemm->datatype)) && + (16 >= (m * k) || 16 >= (k * n) || 16 >= (m * n))) + { + /* TODO: shall we update variable "target_arch" (name)? */ + generated_code.arch = LIBXSMM_X86_AVX2; + } +# endif + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_gemm_kernel, &generated_code, request->descriptor.gemm); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + const int uid = libxsmm_gemm_prefetch2uid((libxsmm_gemm_prefetch_type)request->descriptor.gemm->prefetch); + const char *const tname = libxsmm_typename((libxsmm_datatype)request->descriptor.gemm->datatype); + const char *const meltw_tname = libxsmm_typename((libxsmm_datatype)request->descriptor.gemm->meltw_datatype_aux); + int typesigns = 0, br = 0; + char tc_option[16] = { 0 }; + int decompress_A = 0; + int sparsity_factor_A = 1; + if ((request->descriptor.gemm->meltw_operation == LIBXSMM_MELTW_OPERATION_DECOMPRESS_A) || + (request->descriptor.gemm->meltw_operation == LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A)) + { + decompress_A = 1; + sparsity_factor_A = (int)request->descriptor.gemm->meltw_param; + } + + /* query batch reduce variant */ + if ( (LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS & request->descriptor.gemm->flags) > 1 ) { + br = 1; + } else if ( (LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET & request->descriptor.gemm->flags) > 1 ) { + br = 2; + } else if ( (LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE & request->descriptor.gemm->flags) > 1 ) { + br = 3; + } else { + br = 0; + } + /* query A/B sign combinations */ + if ( (LIBXSMM_GEMM_FLAG_A_UNSIGNED & request->descriptor.gemm->flags) > 1 ) { + typesigns = 1; + } else if ( (LIBXSMM_GEMM_FLAG_B_UNSIGNED & request->descriptor.gemm->flags) > 1 ) { + typesigns = 2; + } else if ( (LIBXSMM_GEMM_FLAG_AB_UNSIGNED & request->descriptor.gemm->flags) > 1 ) { + typesigns = 3; + } else { + typesigns = 0; + } + /* query tileconfig options */ + if (((LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG & request->descriptor.gemm->flags) != 0) && + ((LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG & request->descriptor.gemm->flags) == 0) ) { + LIBXSMM_SNPRINTF(tc_option, sizeof(tc_option), "conf"); + } else if (((LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG & request->descriptor.gemm->flags) == 0) && + ((LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG & request->descriptor.gemm->flags) != 0) ) { + LIBXSMM_SNPRINTF(tc_option, sizeof(tc_option), "rele"); + } else if (((LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG & request->descriptor.gemm->flags) != 0) && + ((LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG & request->descriptor.gemm->flags) != 0)) { + LIBXSMM_SNPRINTF(tc_option, sizeof(tc_option), "none"); + } else { + LIBXSMM_SNPRINTF(tc_option, sizeof(tc_option), "abid"); + } + + if ( request->descriptor.gemm->meltw_operation != 0 ) { + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_a%i_b%i_p%i_br%i_uh%u_si%i_tc-%s_avnni%i_bvnni%i_cvnni%i_meop%u-%s_mefl%u_meld%u-%u-%u_decompress_A%i_spfactor%i.mxm", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.gemm->flags) ? 'n' : 't', m, n, k, + request->descriptor.gemm->lda, request->descriptor.gemm->ldb, request->descriptor.gemm->ldc, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.gemm->flags) ? 0 : 1, uid, + br, (unsigned int)request->descriptor.gemm->c3, typesigns, tc_option, + 0 != (LIBXSMM_GEMM_FLAG_VNNI_A & request->descriptor.gemm->flags) ? 1 : 0, + 0 != (LIBXSMM_GEMM_FLAG_VNNI_B & request->descriptor.gemm->flags) ? 1 : 0, + 0 != (LIBXSMM_GEMM_FLAG_VNNI_C & request->descriptor.gemm->flags) ? 1 : 0, + (unsigned int)request->descriptor.gemm->meltw_operation, meltw_tname, (unsigned int)request->descriptor.gemm->meltw_flags, + request->descriptor.gemm->meltw_ldx, request->descriptor.gemm->meltw_ldy, request->descriptor.gemm->meltw_ldz, decompress_A, sparsity_factor_A ); + } else { + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_a%i_b%i_p%i_br%i_uh%u_si%i_tc-%s_avnni%i_bvnni%i_cvnni%i_decompress_A%i_spfactor%i.mxm", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.gemm->flags) ? 'n' : 't', m, n, k, + request->descriptor.gemm->lda, request->descriptor.gemm->ldb, request->descriptor.gemm->ldc, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.gemm->flags) ? 0 : 1, uid, + br, (unsigned int)request->descriptor.gemm->c3, typesigns, tc_option, + 0 != (LIBXSMM_GEMM_FLAG_VNNI_A & request->descriptor.gemm->flags) ? 1 : 0, + 0 != (LIBXSMM_GEMM_FLAG_VNNI_B & request->descriptor.gemm->flags) ? 1 : 0, + 0 != (LIBXSMM_GEMM_FLAG_VNNI_C & request->descriptor.gemm->flags) ? 1 : 0, decompress_A, sparsity_factor_A ); + } + } + } + } break; + case LIBXSMM_BUILD_KIND_PSPGEMM_CSR: { /* packed sparse gemm kernel, CSR format */ + LIBXSMM_ASSERT(NULL != request->descriptor.pspgemm_csr && 0 != request->descriptor.pspgemm_csr->gemm); + LIBXSMM_ASSERT(NULL != request->descriptor.pspgemm_csr->row_ptr && 0 != request->descriptor.pspgemm_csr->column_idx && 0 != request->descriptor.pspgemm_csr->values); + /* only floating point */ + if (LIBXSMM_GEMM_PRECISION_F64 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pspgemm_csr->gemm->datatype) || + LIBXSMM_GEMM_PRECISION_F32 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pspgemm_csr->gemm->datatype)) + { + const unsigned int nnz = (request->descriptor.pspgemm_csr->gemm->lda == 0) ? + request->descriptor.pspgemm_csr->row_ptr[request->descriptor.pspgemm_csr->gemm->m] : request->descriptor.pspgemm_csr->row_ptr[request->descriptor.pspgemm_csr->gemm->k]; + const unsigned int gemm_factor = (request->descriptor.pspgemm_csr->gemm->lda == 0) ? request->descriptor.pspgemm_csr->gemm->n : request->descriptor.pspgemm_csr->gemm->m; + extra.nflops = 2 * nnz * gemm_factor * request->descriptor.pspgemm_csr->packed_width; + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_packed_spgemm_csr_kernel, &generated_code, request->descriptor.pspgemm_csr->gemm, + request->descriptor.pspgemm_csr->row_ptr, request->descriptor.pspgemm_csr->column_idx, request->descriptor.pspgemm_csr->values, request->descriptor.pspgemm_csr->packed_width); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + const int uid = libxsmm_gemm_prefetch2uid((libxsmm_gemm_prefetch_type)request->descriptor.pspgemm_csr->gemm->prefetch); + const char *const tname = libxsmm_typename((libxsmm_datatype)request->descriptor.pspgemm_csr->gemm->datatype); + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_w%u_a%i_b%i_p%i_nnz%u.pspgemm_csr", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.pspgemm_csr->gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.pspgemm_csr->gemm->flags) ? 'n' : 't', + request->descriptor.pspgemm_csr->gemm->m, request->descriptor.pspgemm_csr->gemm->n, request->descriptor.pspgemm_csr->gemm->k, + request->descriptor.pspgemm_csr->gemm->lda, request->descriptor.pspgemm_csr->gemm->ldb, request->descriptor.pspgemm_csr->gemm->ldc, + request->descriptor.pspgemm_csr->packed_width, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.pspgemm_csr->gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.pspgemm_csr->gemm->flags) ? 0 : 1, + uid, nnz); + } + } + } break; + case LIBXSMM_BUILD_KIND_PSPGEMM_CSC: { /* packed sparse gemm kernel, CSC format */ + LIBXSMM_ASSERT(NULL != request->descriptor.pspgemm_csc && 0 != request->descriptor.pspgemm_csc->gemm); + LIBXSMM_ASSERT(NULL != request->descriptor.pspgemm_csc->row_idx && 0 != request->descriptor.pspgemm_csc->column_ptr && 0 != request->descriptor.pspgemm_csc->values); + /* only floating point */ + if (LIBXSMM_GEMM_PRECISION_F64 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pspgemm_csc->gemm->datatype) || + LIBXSMM_GEMM_PRECISION_F32 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pspgemm_csc->gemm->datatype)) + { + const unsigned int nnz = (request->descriptor.pspgemm_csc->gemm->lda == 0) ? + request->descriptor.pspgemm_csc->column_ptr[request->descriptor.pspgemm_csc->gemm->k] : request->descriptor.pspgemm_csc->column_ptr[request->descriptor.pspgemm_csc->gemm->n]; + const unsigned int gemm_factor = (request->descriptor.pspgemm_csc->gemm->lda == 0) ? request->descriptor.pspgemm_csc->gemm->n : request->descriptor.pspgemm_csc->gemm->m; + extra.nflops = 2 * nnz * gemm_factor * request->descriptor.pspgemm_csc->packed_width; + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_packed_spgemm_csc_kernel, &generated_code, request->descriptor.pspgemm_csc->gemm, + request->descriptor.pspgemm_csc->row_idx, request->descriptor.pspgemm_csc->column_ptr, request->descriptor.pspgemm_csc->values, request->descriptor.pspgemm_csc->packed_width); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + const int uid = libxsmm_gemm_prefetch2uid((libxsmm_gemm_prefetch_type)request->descriptor.pspgemm_csc->gemm->prefetch); + const char *const tname = libxsmm_typename((libxsmm_datatype)request->descriptor.pspgemm_csc->gemm->datatype); + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_w%u_a%i_b%i_p%i_nnz%u.pspgemm_csc", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.pspgemm_csc->gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.pspgemm_csc->gemm->flags) ? 'n' : 't', + request->descriptor.pspgemm_csc->gemm->m, request->descriptor.pspgemm_csc->gemm->n, request->descriptor.pspgemm_csc->gemm->k, + request->descriptor.pspgemm_csc->gemm->lda, request->descriptor.pspgemm_csc->gemm->ldb, request->descriptor.pspgemm_csc->gemm->ldc, + request->descriptor.pspgemm_csc->packed_width, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.pspgemm_csc->gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.pspgemm_csc->gemm->flags) ? 0 : 1, + uid, nnz); + } + } + } break; + case LIBXSMM_BUILD_KIND_PGEMMRMAC: { /* packed GEMM, B regular matrix, row-major */ + LIBXSMM_ASSERT(NULL != request->descriptor.pgemmacrm && 0 != request->descriptor.pgemmacrm->gemm); + /* only floating point */ + if (LIBXSMM_GEMM_PRECISION_F64 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pgemmacrm->gemm->datatype) || + LIBXSMM_GEMM_PRECISION_F32 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pgemmacrm->gemm->datatype)) + { + extra.nflops = 2 * request->descriptor.pgemmacrm->packed_width * request->descriptor.pgemmacrm->gemm->m * request->descriptor.pgemmacrm->gemm->n * request->descriptor.pgemmacrm->gemm->k; + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_packed_gemm_ac_rm, &generated_code, request->descriptor.pgemmacrm->gemm, request->descriptor.pgemmacrm->packed_width); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + const int uid = libxsmm_gemm_prefetch2uid((libxsmm_gemm_prefetch_type)request->descriptor.pgemmacrm->gemm->prefetch); + const char *const tname = libxsmm_typename((libxsmm_datatype)request->descriptor.pgemmacrm->gemm->datatype); + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_w%u_a%i_b%i_p%i.pgemmacrm", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.pgemmacrm->gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.pgemmacrm->gemm->flags) ? 'n' : 't', + request->descriptor.pgemmacrm->gemm->m, request->descriptor.pgemmacrm->gemm->n, request->descriptor.pgemmacrm->gemm->k, + request->descriptor.pgemmacrm->gemm->lda, request->descriptor.pgemmacrm->gemm->ldb, request->descriptor.pgemmacrm->gemm->ldc, + request->descriptor.pgemmacrm->packed_width, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.pgemmacrm->gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.pgemmacrm->gemm->flags) ? 0 : 1, + uid); + } + } + } break; + case LIBXSMM_BUILD_KIND_PGEMMRMBC: { /* packed GEMM, A regular matrix, row-major */ + LIBXSMM_ASSERT(NULL != request->descriptor.pgemmbcrm && 0 != request->descriptor.pgemmbcrm->gemm); + /* only floating point */ + if (LIBXSMM_GEMM_PRECISION_F64 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pgemmbcrm->gemm->datatype) || + LIBXSMM_GEMM_PRECISION_F32 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.pgemmbcrm->gemm->datatype)) + { + extra.nflops = 2 * request->descriptor.pgemmbcrm->packed_width * request->descriptor.pgemmbcrm->gemm->m * request->descriptor.pgemmbcrm->gemm->n * request->descriptor.pgemmbcrm->gemm->k; + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_packed_gemm_bc_rm, &generated_code, request->descriptor.pgemmbcrm->gemm, request->descriptor.pgemmbcrm->packed_width); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + const int uid = libxsmm_gemm_prefetch2uid((libxsmm_gemm_prefetch_type)request->descriptor.pgemmbcrm->gemm->prefetch); + const char *const tname = libxsmm_typename((libxsmm_datatype)request->descriptor.pgemmbcrm->gemm->datatype); + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_w%u_a%i_b%i_p%i.pgemmbcrm", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.pgemmbcrm->gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.pgemmbcrm->gemm->flags) ? 'n' : 't', + request->descriptor.pgemmbcrm->gemm->m, request->descriptor.pgemmbcrm->gemm->n, request->descriptor.pgemmbcrm->gemm->k, + request->descriptor.pgemmbcrm->gemm->lda, request->descriptor.pgemmbcrm->gemm->ldb, request->descriptor.pgemmbcrm->gemm->ldc, + request->descriptor.pgemmbcrm->packed_width, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.pgemmbcrm->gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.pgemmbcrm->gemm->flags) ? 0 : 1, + uid); + } + } + } break; + case LIBXSMM_BUILD_KIND_SREG: { /* sparse register kernel */ + LIBXSMM_ASSERT(NULL != request->descriptor.sreg && 0 != request->descriptor.sreg->gemm); + LIBXSMM_ASSERT(NULL != request->descriptor.sreg->row_ptr && 0 != request->descriptor.sreg->column_idx && 0 != request->descriptor.sreg->values); + /* only floating point */ + if (LIBXSMM_GEMM_PRECISION_F64 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.sreg->gemm->datatype) || + LIBXSMM_GEMM_PRECISION_F32 == /*LIBXSMM_GETENUM_OUT*/(request->descriptor.sreg->gemm->datatype)) + { + const unsigned int nnz = request->descriptor.sreg->row_ptr[request->descriptor.sreg->gemm->m]; + extra.nflops = 2 * libxsmm_cpuid_vlen32(libxsmm_target_archid)/2 * request->descriptor.sreg->gemm->n * nnz; + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_spgemm_csr_reg_kernel, &generated_code, request->descriptor.sreg->gemm, target_arch, + request->descriptor.sreg->row_ptr, request->descriptor.sreg->column_idx, + (const double*)request->descriptor.sreg->values); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + const int uid = libxsmm_gemm_prefetch2uid((libxsmm_gemm_prefetch_type)request->descriptor.sreg->gemm->prefetch); + const char *const tname = libxsmm_typename((libxsmm_datatype)request->descriptor.sreg->gemm->datatype); + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_%s_%c%c_%ux%ux%u_%u_%u_%u_a%i_b%i_p%i.sreg", target_arch, tname, + 0 == (LIBXSMM_GEMM_FLAG_TRANS_A & request->descriptor.sreg->gemm->flags) ? 'n' : 't', + 0 == (LIBXSMM_GEMM_FLAG_TRANS_B & request->descriptor.sreg->gemm->flags) ? 'n' : 't', + request->descriptor.sreg->gemm->m, request->descriptor.sreg->gemm->n, request->descriptor.sreg->gemm->k, + request->descriptor.sreg->gemm->lda, request->descriptor.sreg->gemm->ldb, request->descriptor.sreg->gemm->ldc, + /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & request->descriptor.sreg->gemm->flags) ? 0 : */1, + 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & request->descriptor.sreg->gemm->flags) ? 0 : 1, + uid); + } + } + } break; + case LIBXSMM_BUILD_KIND_MELTW: { /* matcopy kernel */ + LIBXSMM_ASSERT(NULL != request->descriptor.meltw); + { + /* dispatch eltwise code with AVX512_BF16 by demoting seemlessly to the current CPU arch */ + if ( ( generated_code.arch >= LIBXSMM_X86_AVX512_SPR ) && + ( generated_code.arch <= LIBXSMM_X86_ALLFEAT ) ) { + int emu_amx = 0; + const char *const env_emu_amx = getenv("EMULATE_AMX"); + if ( 0 == env_emu_amx ) { + } else { + emu_amx = atoi(env_emu_amx); + } + if (emu_amx > 0) { + generated_code.arch = libxsmm_cpuid(); + } + } + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_mateltwise_kernel, &generated_code, request->descriptor.meltw); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + char tsizename[4]; + internal_get_typesize_string(tsizename, sizeof(tsizename), request->descriptor.meltw->datatype); + /* adopt scheme which allows kernel names of LIBXSMM to appear in order (Intel VTune, etc.) */ + if ( request->descriptor.meltw->operation == LIBXSMM_MELTW_OPERATION_REDUCE_COLS_IDX ) { + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_tsize%s_idxtsize%u_%u_%ux%u_opcode%u_flags%u.meltw", target_arch, tsizename, + request->descriptor.meltw->n, request->descriptor.meltw->m, request->descriptor.meltw->ldi, request->descriptor.meltw->ldo, + (unsigned int)request->descriptor.meltw->operation, (unsigned int)request->descriptor.meltw->flags); + } else { + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_tsize%s_%ux%u_%ux%u_opcode%u_flags%u_params%u.meltw", target_arch, tsizename, + request->descriptor.meltw->m, request->descriptor.meltw->n, request->descriptor.meltw->ldi, request->descriptor.meltw->ldo, + (unsigned int)request->descriptor.meltw->operation, (unsigned int)request->descriptor.meltw->flags, (unsigned int)request->descriptor.meltw->param); + } + } + } + } break; + case LIBXSMM_BUILD_KIND_MEQN: { /* matequation kernel */ + LIBXSMM_ASSERT(NULL != request->descriptor.meltw); + { + /* dispatch eltwise code with AVX512_BF16 by demoting seemlessly to the current CPU arch */ + if ( ( generated_code.arch >= LIBXSMM_X86_AVX512_SPR ) && + ( generated_code.arch <= LIBXSMM_X86_ALLFEAT ) ) { + int emu_amx = 0; + const char *const env_emu_amx = getenv("EMULATE_AMX"); + if ( 0 == env_emu_amx ) { + } else { + emu_amx = atoi(env_emu_amx); + } + if (emu_amx > 0) { + generated_code.arch = libxsmm_cpuid(); + } + } + LIBXSMM_NO_OFFLOAD(void, libxsmm_generator_matequation_kernel, &generated_code, request->descriptor.meqn); +# if !defined(LIBXSMM_VTUNE) + if (0 > libxsmm_verbosity) +# endif + { + char tsizename[4]; + internal_get_typesize_string(tsizename, sizeof(tsizename), request->descriptor.meqn->datatype); + LIBXSMM_SNPRINTF(jit_name, sizeof(jit_name), "libxsmm_%s_tsize%s_%ux%u_%u_eqn-idx%u.meltw", target_arch, tsizename, + request->descriptor.meqn->m, request->descriptor.meqn->n, request->descriptor.meqn->ldo, + (unsigned int)request->descriptor.meqn->eqn_idx ); + } + } + } break; + case LIBXSMM_BUILD_KIND_USER: break; +# if !defined(NDEBUG) /* library code is expected to be mute */ + default: { /* unknown kind */ + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM ERROR: invalid build request discovered!\n"); + } + /*result = EXIT_FAILURE;*/ + } +# endif + } + + if (0 == generated_code.last_error /* no error raised */ + && 0 != generated_code.code_size /*check (tcopy issue?)*/) + { + char* code_buffer = NULL; +# if defined(__APPLE__) && defined(__arm64__) +# else + void* code_buffer_result = &code_buffer; +# endif + LIBXSMM_ASSERT(generated_code.code_size <= LIBXSMM_CODE_MAXSIZE); + LIBXSMM_ASSERT(NULL != generated_code.generated_code); + /* attempt to create executable buffer */ +# if defined(__APPLE__) && defined(__arm64__) + code_buffer = mmap( 0, generated_code.code_size, PROT_WRITE | PROT_EXEC | PROT_READ, + MAP_PRIVATE | MAP_ANONYMOUS | MAP_JIT, -1, 0 ); + if ( (long long)code_buffer >= 0 ) { + result = EXIT_SUCCESS; + } else { + result = EXIT_FAILURE; + } +# else + result = libxsmm_xmalloc((void**)code_buffer_result, generated_code.code_size, 0/*auto*/, + /* flag must be a superset of what's populated by libxsmm_malloc_attrib */ + LIBXSMM_MALLOC_FLAG_RWX, &extra, sizeof(extra)); +# endif + if (EXIT_SUCCESS == result) { /* check for success */ + LIBXSMM_ASSERT(NULL != code_buffer); +# if defined(__APPLE__) && defined(__arm64__) + pthread_jit_write_protect_np(0/*false*/); +# endif + /* copy temporary buffer into the prepared executable buffer */ +# if defined(NDEBUG) + { int i; /* precondition: jit_buffer == generated_code.generated_code */ + for (i = 0; i < (int)generated_code.code_size; ++i) code_buffer[i] = jit_buffer[i]; + } +# else + memcpy(code_buffer, generated_code.generated_code, generated_code.code_size); +# endif +# if defined(__APPLE__) && defined(__arm64__) + code->ptr = code_buffer; /* commit buffer */ + LIBXSMM_ASSERT(NULL != code->ptr && 0 == (LIBXSMM_CODE_STATIC & code->uval)); + sys_icache_invalidate(code_buffer, generated_code.code_size); + pthread_jit_write_protect_np(1/*true*/); +# else + /* attribute/protect buffer and revoke unnecessary flags */ + result = libxsmm_malloc_attrib((void**)code_buffer_result, LIBXSMM_MALLOC_FLAG_X, jit_name); + if (EXIT_SUCCESS == result) { /* check for success */ + code->ptr = code_buffer; /* commit buffer */ + LIBXSMM_ASSERT(NULL != code->ptr && 0 == (LIBXSMM_CODE_STATIC & code->uval)); +# if defined(__aarch64__) +# if defined(__clang__) + __clear_cache(code_buffer, code_buffer + generated_code.code_size); +# else + __builtin___clear_cache(code_buffer, code_buffer + generated_code.code_size); +# endif +# endif + } + else { /* release buffer */ + libxsmm_xfree(code_buffer, 0/*no check*/); + } +# endif + } + } + else if (request->kind == LIBXSMM_BUILD_KIND_USER && NULL != request->descriptor.ptr) { /* user-data */ + if (0 != request->user_size) { + void* user_data = &code->ptr; + result = libxsmm_xmalloc((void**)user_data, request->user_size, 0/*auto*/, + LIBXSMM_MALLOC_FLAG_PRIVATE, &extra, sizeof(extra)); + } + else { + result = EXIT_SUCCESS; + code->ptr = NULL; + } + } + else { + result = (0 != generated_code.last_error ? generated_code.last_error : EXIT_FAILURE); + } +#else /* unsupported platform */ + LIBXSMM_UNUSED(request); LIBXSMM_UNUSED(regindex); LIBXSMM_UNUSED(code); + /* libxsmm_get_target_arch also serves as a runtime check whether JIT is available or not */ + if (LIBXSMM_X86_GENERIC <= libxsmm_target_archid) result = EXIT_FAILURE; +#endif + return result; +} + + +LIBXSMM_API_INLINE void internal_pad_descriptor(libxsmm_descriptor* desc, signed char size) +{ + LIBXSMM_ASSERT(LIBXSMM_DESCRIPTOR_MAXSIZE < 128 && NULL != desc); + LIBXSMM_ASSERT(LIBXSMM_DIFF_SIZE <= LIBXSMM_DESCRIPTOR_MAXSIZE); + LIBXSMM_ASSERT(LIBXSMM_HASH_SIZE <= LIBXSMM_DIFF_SIZE); + for (; size < LIBXSMM_DIFF_SIZE; ++size) desc->data[size] = 0; +} + + +LIBXSMM_API_INLINE libxsmm_code_pointer internal_find_code(libxsmm_descriptor* desc, size_t desc_size, size_t user_size, unsigned int* hash) +{ + libxsmm_code_pointer flux_entry = { 0 }; + const int is_big_desc = LIBXSMM_DESCRIPTOR_ISBIG(desc->kind); + const signed char size = (signed char)(sizeof(libxsmm_descriptor_kind) + desc_size); + LIBXSMM_DIFF_DECL(LIBXSMM_DIFF_SIZE, xdesc); +#if !defined(NDEBUG) && (0 != LIBXSMM_JIT) + int build = EXIT_SUCCESS; +#endif +#if defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) +# if defined(LIBXSMM_NTHREADS_USE) + const unsigned int tid = libxsmm_get_tid(); + internal_cache_type *const cache = internal_cache_buffer + tid; +# else + static LIBXSMM_TLS internal_cache_type internal_cache_buffer; + internal_cache_type *const cache = &internal_cache_buffer; +# endif + unsigned char cache_index; + const unsigned int ninit = LIBXSMM_ATOMIC_LOAD(&libxsmm_ninit, LIBXSMM_ATOMIC_RELAXED); + internal_pad_descriptor(desc, size); + LIBXSMM_ASSERT(NULL != hash); + if (0 == is_big_desc) { + LIBXSMM_DIFF_LOAD(LIBXSMM_DIFF_SIZE, xdesc, desc); + LIBXSMM_DIFF_N(unsigned char, cache_index, LIBXSMM_DIFF(LIBXSMM_DIFF_SIZE), xdesc, cache->entry.keys, + LIBXSMM_DIFF_SIZE, LIBXSMM_CACHE_STRIDE, cache->entry.hit, cache->entry.size); + } + else { + cache_index = (unsigned char)libxsmm_diff_n(desc, cache->entry.keys, + size, LIBXSMM_CACHE_STRIDE, cache->entry.hit, cache->entry.size); + } + if (ninit == cache->entry.id && cache_index < cache->entry.size) { /* valid hit */ + flux_entry = cache->entry.code[cache_index]; + cache->entry.hit = cache_index; + } + else +#else + internal_pad_descriptor(desc, size); + LIBXSMM_ASSERT(NULL != hash); +#endif + { + unsigned int i, i0, mode = 0, diff = 1; + *hash = LIBXSMM_CRC32(LIBXSMM_HASH_SIZE)(LIBXSMM_HASH_SEED, desc); + i0 = i = LIBXSMM_MOD2(*hash, LIBXSMM_CAPACITY_REGISTRY); + LIBXSMM_ASSERT(&desc->kind == &desc->gemm.pad && desc->kind == desc->gemm.pad); + LIBXSMM_ASSERT(NULL != internal_registry); + do { /* use calculated location and check if the requested code is already JITted */ +#if (1 < INTERNAL_REGLOCK_MAXN) || !LIBXSMM_LOCK_TYPE_ISRW(LIBXSMM_REGLOCK) /* read registered code */ +# if 1 /* omitting an atomic load is safe but avoids race-detectors to highlight this location */ + uintptr_t *const fluxaddr = &internal_registry[i].uval; + flux_entry.uval = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_LOAD, LIBXSMM_BITS)(fluxaddr, LIBXSMM_ATOMIC_RELAXED); +# else + flux_entry = internal_registry[i]; +# endif +#else + LIBXSMM_LOCK_ACQREAD(LIBXSMM_REGLOCK, internal_reglock_ptr); + flux_entry = internal_registry[i]; /* read registered code */ + LIBXSMM_LOCK_RELREAD(LIBXSMM_REGLOCK, internal_reglock_ptr); +#endif + if ((NULL != flux_entry.ptr_const || 1 == mode) && 2 > mode) { /* confirm entry */ + if (NULL != flux_entry.ptr_const) { + if (0 == is_big_desc) { +#if !defined(LIBXSMM_CACHE_MAXSIZE) || (0 == (LIBXSMM_CACHE_MAXSIZE)) + LIBXSMM_DIFF_LOAD(LIBXSMM_DIFF_SIZE, xdesc, desc); +#endif + diff = LIBXSMM_DIFF(LIBXSMM_DIFF_SIZE)(xdesc, internal_registry_keys + i, 0/*dummy*/); + } + else { + diff = libxsmm_diff(desc, internal_registry_keys + i, size); + } + } +#if !defined(NDEBUG) + else LIBXSMM_ASSERT(0 != diff); +#endif + if (0 != diff) { /* search for code version */ + if (0 == mode) { /* transition to higher mode */ + i0 = i; /* keep current position on record */ +#if defined(LIBXSMM_HASH_COLLISION) + /* enter code generation, and collision fix-up */ + if (0 == (LIBXSMM_HASH_COLLISION & flux_entry.uval)) { + LIBXSMM_ASSERT(NULL != flux_entry.ptr_const); /* collision */ + mode = 3; + } + else +#endif /* search for an existing code version */ + mode = 1; /* else */ + } + i = LIBXSMM_MOD2(i + 1, LIBXSMM_CAPACITY_REGISTRY); + if (i == i0) { /* search finished, no code version exists */ +#if defined(LIBXSMM_HASH_COLLISION) + mode = 3; /* enter code generation, and collision fix-up */ +#else + mode = 2; /* enter code generation */ +#endif + if (LIBXSMM_KERNEL_KIND_MATMUL == LIBXSMM_DESCRIPTOR_KIND(desc->kind)) { + internal_update_mmstatistic(&desc->gemm.desc, 0, 1/*collision*/, 0, 0); + } + } + LIBXSMM_ASSERT(0 != diff); /* continue */ + } + } + else { /* enter code generation (there is no code version yet) */ + LIBXSMM_ASSERT(0 == mode || 1 < mode); +#if (0 == LIBXSMM_JIT) + LIBXSMM_UNUSED(user_size); +#else + if (LIBXSMM_X86_GENERIC <= libxsmm_target_archid || /* check if JIT is supported (CPUID) */ + (LIBXSMM_KERNEL_KIND_USER == LIBXSMM_DESCRIPTOR_KIND(desc->kind))) + { + LIBXSMM_ASSERT(0 != mode || NULL == flux_entry.ptr_const/*code version does not exist*/); + INTERNAL_FIND_CODE_LOCK(lock, i, diff, flux_entry.ptr); /* lock the registry entry */ + if (NULL == internal_registry[i].ptr_const) { /* double-check registry after acquiring the lock */ + libxsmm_build_request request; /* setup the code build request */ + LIBXSMM_ASSERT(LIBXSMM_KERNEL_UNREGISTERED > LIBXSMM_DESCRIPTOR_KIND(desc->kind)); + request.kind = (libxsmm_build_kind)LIBXSMM_DESCRIPTOR_KIND(desc->kind); + request.descriptor.ptr = &desc->gemm.desc; + request.user_size = user_size; +# if defined(NDEBUG) + if (EXIT_SUCCESS == libxsmm_build(&request, i, &flux_entry) && NULL != flux_entry.ptr_const) +# else + build = libxsmm_build(&request, i, &flux_entry); + if (EXIT_SUCCESS == build && NULL != flux_entry.ptr_const) +# endif + { + LIBXSMM_ASSIGN127(internal_registry_keys + i, desc); +# if (1 < INTERNAL_REGLOCK_MAXN) + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, LIBXSMM_BITS)(&internal_registry[i].ptr, flux_entry.ptr, LIBXSMM_ATOMIC_SEQ_CST); +# else + internal_registry[i] = flux_entry; +# endif +# if defined(LIBXSMM_HASH_COLLISION) + if (2 < mode) { /* arrived from collision state; now mark as collision */ + libxsmm_code_pointer fix_entry; +# if (1 < INTERNAL_REGLOCK_MAXN) + fix_entry.ptr = LIBXSMM_ATOMIC_LOAD(&internal_registry[i0].ptr, LIBXSMM_ATOMIC_RELAXED); +# else + fix_entry = internal_registry[i0]; +# endif + LIBXSMM_ASSERT(NULL != fix_entry.ptr_const); + if (0 == (LIBXSMM_HASH_COLLISION & fix_entry.uval)) { + fix_entry.uval |= LIBXSMM_HASH_COLLISION; /* mark current entry as collision */ +# if (1 < INTERNAL_REGLOCK_MAXN) + LIBXSMM_ATOMIC_STORE(&internal_registry[i0].ptr, fix_entry.ptr, LIBXSMM_ATOMIC_RELAXED); +# else + internal_registry[i0] = fix_entry; +# endif + } + } +# endif + } + if (LIBXSMM_KERNEL_KIND_MATMUL == LIBXSMM_DESCRIPTOR_KIND(desc->kind)) { + internal_update_mmstatistic(&desc->gemm.desc, 1/*try*/, 0, 0, 0); + } + /* leave here even in case of a build-error; do not use break (inside of locked region) */ + diff = 0; + } + INTERNAL_FIND_CODE_UNLOCK(lock); + if (0 != diff) { /* acquire registry slot */ + if (0 == mode) { /* initial condition */ + mode = 2; /* continue to linearly search for an empty slot */ + i0 = i; /* keep current position on record */ + } + do { /* continue to linearly search for an available slot */ + i = LIBXSMM_MOD2(i + 1, LIBXSMM_CAPACITY_REGISTRY); + if (NULL == internal_registry[i].ptr_const) break; + } while (i != i0); + if (i == i0) { /* out of capacity (no registry slot available) */ + diff = 0; /* do not use break if inside of locked region */ + } + flux_entry.ptr = NULL; /* no result */ + } + } + else /* JIT-code generation not available */ +#endif + { /* leave the dispatch loop */ + if (LIBXSMM_KERNEL_KIND_MATMUL == LIBXSMM_DESCRIPTOR_KIND(desc->kind)) { + internal_update_mmstatistic(&desc->gemm.desc, 1/*try*/, 0, 0, 0); + } +#if !defined(NDEBUG) && (0 != LIBXSMM_JIT) + build = EXIT_FAILURE; +#endif + flux_entry.ptr = NULL; + diff = 0; + } + } + } while (0 != diff); +#if defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + if (NULL != flux_entry.ptr_const) { /* keep code version on record (cache) */ + LIBXSMM_ASSERT(0 == diff); + if (ninit == cache->entry.id) { /* maintain cache */ + if (cache->entry.size < internal_cache_size) { /* grow */ + INTERNAL_FIND_CODE_CACHE_GROW(cache_index, cache->entry.size); + LIBXSMM_ASSERT(cache->entry.size <= internal_cache_size); + } + else { /* evict */ + LIBXSMM_ASSERT(cache->entry.hit < cache->entry.size); + INTERNAL_FIND_CODE_CACHE_EVICT(cache_index, cache->entry.size, cache->entry.hit); + } + } + else if (0 != internal_cache_size) { /* reset cache */ + /* INTERNAL_FIND_CODE_CACHE_GROW doubles size (and would expose invalid entries) */ + memset(cache->entry.keys, 0, LIBXSMM_CACHE_MAXSIZE * sizeof(*cache->entry.keys)); + cache->entry.id = ninit; + cache->entry.size = 1; + cache_index = 0; + } + LIBXSMM_MEMCPY127(cache->entry.keys + cache_index, desc, 0 == is_big_desc ? LIBXSMM_DIFF_SIZE : size); + cache->entry.code[cache_index] = flux_entry; + cache->entry.hit = cache_index; + } +# if !defined(NDEBUG) + else { + memset(cache, 0, sizeof(*cache)); + } +# endif +#endif + } +#if defined(LIBXSMM_HASH_COLLISION) + flux_entry.uval &= ~(LIBXSMM_CODE_STATIC | LIBXSMM_HASH_COLLISION); /* clear non-JIT and collision flag */ +#else + flux_entry.uval &= ~LIBXSMM_CODE_STATIC; /* clear non-JIT flag */ +#endif +#if (0 != LIBXSMM_JIT) + assert( /*!LIBXSMM_ASSERT*/ + LIBXSMM_KERNEL_KIND_MATMUL != LIBXSMM_DESCRIPTOR_KIND(desc->kind) + || NULL != flux_entry.ptr_const + || 1 == internal_reglock_count + || EXIT_SUCCESS != build); +#endif + return flux_entry; +} + + +LIBXSMM_API_INTERN const libxsmm_kernel_xinfo* libxsmm_get_kernel_xinfo(libxsmm_code_pointer code, + const libxsmm_descriptor** desc, size_t* code_size) +{ + libxsmm_kernel_xinfo* result = NULL; + void *const result_address = &result; + int flags = LIBXSMM_MALLOC_FLAG_X; + if (NULL != code.ptr_const && EXIT_SUCCESS == libxsmm_get_malloc_xinfo( + code.ptr_const, code_size, &flags, (void**)result_address) && NULL != result) + { + if (NULL != desc) { + if (NULL != internal_registry && NULL != internal_registry_keys && result->registered < (LIBXSMM_CAPACITY_REGISTRY) +#if defined(LIBXSMM_HASH_COLLISION) + && code.uval == (~LIBXSMM_HASH_COLLISION & internal_registry[result->registered].uval) +#else + && code.ptr_const == internal_registry[result->registered].ptr_const +#endif + && LIBXSMM_KERNEL_UNREGISTERED > LIBXSMM_DESCRIPTOR_KIND(internal_registry_keys[result->registered].entry.kind)) + { + *desc = &internal_registry_keys[result->registered].entry; + } + else *desc = NULL; + } + } + else { + LIBXSMM_ASSERT(NULL == result); + if (NULL != code_size) *code_size = 0; + if (NULL != desc) *desc = NULL; + } + return result; +} + + +LIBXSMM_API int libxsmm_get_kernel_info(const void* kernel, libxsmm_kernel_info* info) +{ + int result; + const libxsmm_kernel_xinfo* xinfo; + libxsmm_kernel_info result_info; + const libxsmm_descriptor* desc; + libxsmm_code_pointer code; + code.ptr_const = kernel; + LIBXSMM_MEMZERO127(&result_info); + xinfo = libxsmm_get_kernel_xinfo(code, &desc, &result_info.code_size); + if (NULL != xinfo) { + if (NULL != desc) { + const libxsmm_kernel_kind kind = (libxsmm_kernel_kind)LIBXSMM_DESCRIPTOR_KIND(desc->kind); + result_info.kind = kind; + if (LIBXSMM_KERNEL_KIND_USER == kind) { + result_info.code_size = 0; /* invalid */ + } + } + else { + result_info.kind = LIBXSMM_KERNEL_UNREGISTERED; + } + result_info.nflops = xinfo->nflops; + LIBXSMM_ASSIGN127(info, &result_info); + result = EXIT_SUCCESS; + } + else { + LIBXSMM_ASSERT(NULL == desc); + if (NULL != info) { + LIBXSMM_ASSIGN127(info, &result_info); + result = EXIT_FAILURE; + } + else { + result = EXIT_SUCCESS; + } + } + return result; +} + + +LIBXSMM_API int libxsmm_get_mmkernel_info(libxsmm_xmmfunction kernel, libxsmm_mmkernel_info* info) +{ + libxsmm_code_pointer code; + static int error_once = 0; + int result; + code.xgemm = kernel; + if (NULL != info) { + const libxsmm_descriptor* desc; + if (NULL != libxsmm_get_kernel_xinfo(code, &desc, NULL/*code_size*/) && + NULL != desc && LIBXSMM_KERNEL_KIND_MATMUL == LIBXSMM_DESCRIPTOR_KIND(desc->kind)) + { + info->iprecision = (libxsmm_gemm_precision)LIBXSMM_GETENUM_INP(desc->gemm.desc.datatype); + info->oprecision = (libxsmm_gemm_precision)LIBXSMM_GETENUM_OUT(desc->gemm.desc.datatype); + info->prefetch = (libxsmm_gemm_prefetch_type)desc->gemm.desc.prefetch; + info->flags = desc->gemm.desc.flags; + info->lda = desc->gemm.desc.lda; + info->ldb = desc->gemm.desc.ldb; + info->ldc = desc->gemm.desc.ldc; + info->m = desc->gemm.desc.m; + info->n = desc->gemm.desc.n; + info->k = desc->gemm.desc.k; + result = EXIT_SUCCESS; + } + else { +#if defined(__APPLE__) && defined(__arm64__) + info->iprecision = 1; + info->oprecision = 1; + info->prefetch = 1; + info->flags = 1; + info->lda = 1; + info->ldb = 1; + info->ldc = 1; + info->m = 1; + info->n = 1; + info->k = 1; + result = EXIT_SUCCESS; +# else + if ( 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (NULL == code.ptr_const) { + fprintf(stderr, "LIBXSMM ERROR: NULL-kernel cannot be inspected!\n"); + } + else { + fprintf(stderr, "LIBXSMM ERROR: invalid kernel cannot be inspected!\n"); + } + } + result = EXIT_FAILURE; +# endif + } + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument!\n"); + } + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API int libxsmm_get_meltwkernel_info(libxsmm_xmeltwfunction kernel, libxsmm_meltwkernel_info* info) +{ + libxsmm_code_pointer code; + static int error_once = 0; + int result; + code.xmateltw = kernel; + if (NULL != info) { + const libxsmm_descriptor* desc; + if (NULL != libxsmm_get_kernel_xinfo(code, &desc, NULL/*code_size*/) && + NULL != desc && LIBXSMM_KERNEL_KIND_MELTW == LIBXSMM_DESCRIPTOR_KIND(desc->kind)) + { + info->datatype = desc->meltw.desc.datatype; + info->operation = desc->meltw.desc.operation; + info->flags = desc->meltw.desc.flags; + info->ldi = desc->meltw.desc.ldi; + info->ldo = desc->meltw.desc.ldo; + info->m = desc->meltw.desc.m; + info->n = desc->meltw.desc.n; + result = EXIT_SUCCESS; + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid kernel cannot be inspected!\n"); + } + result = EXIT_FAILURE; + } + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument!\n"); + } + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API int libxsmm_get_registry_info(libxsmm_registry_info* info) +{ + int result = EXIT_SUCCESS; + LIBXSMM_INIT /* verbosity */ + if (0 != info && 0 != internal_registry) { + size_t i; + LIBXSMM_MEMZERO127(info); /* info->nstatic = 0; info->size = 0; */ + info->nbytes = (LIBXSMM_CAPACITY_REGISTRY) * (sizeof(libxsmm_code_pointer) + sizeof(libxsmm_descriptor)); + info->capacity = LIBXSMM_CAPACITY_REGISTRY; +#if defined(LIBXSMM_CACHE_MAXSIZE) && (0 < (LIBXSMM_CACHE_MAXSIZE)) + info->ncache = internal_cache_size; +#else + info->ncache = 0; +#endif + for (i = 0; i < (LIBXSMM_CAPACITY_REGISTRY); ++i) { + libxsmm_code_pointer code = internal_registry[i]; + if (0 != code.ptr_const && EXIT_SUCCESS == result) { + if (0 == (LIBXSMM_CODE_STATIC & code.uval)) { /* check for allocated/generated JIT-code */ + size_t buffer_size = 0; + void* buffer = 0; +#if defined(LIBXSMM_HASH_COLLISION) + code.uval &= ~LIBXSMM_HASH_COLLISION; /* clear collision flag */ +#endif + result = libxsmm_get_malloc_xinfo(code.ptr_const, &buffer_size, NULL/*flags*/, &buffer); + if (EXIT_SUCCESS == result) { + info->nbytes += LIBXSMM_UP2(buffer_size + (((char*)code.ptr_const) - (char*)buffer), LIBXSMM_PAGE_MINSIZE); + } + } + else { + ++info->nstatic; + } + ++info->size; + } + } + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API_INLINE void* internal_get_registry_entry(int i, libxsmm_kernel_kind kind, const void** key) +{ + void* result = NULL; + LIBXSMM_ASSERT(kind < LIBXSMM_KERNEL_UNREGISTERED && NULL != internal_registry); + for (; i < (LIBXSMM_CAPACITY_REGISTRY); ++i) { + const libxsmm_code_pointer regentry = internal_registry[i]; + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(regentry.ptr_const, + NULL/*code_size*/, NULL/*flags*/, &result) && NULL != result) + { + const libxsmm_kernel_xinfo info = *(const libxsmm_kernel_xinfo*)result; + const libxsmm_descriptor *const desc = &internal_registry_keys[info.registered].entry; + if (LIBXSMM_DESCRIPTOR_KIND(desc->kind) == (int)kind) { + if (NULL != key) *key = desc->user.desc; + result = regentry.ptr; + break; + } + } + } + return result; +} + + +LIBXSMM_API void* libxsmm_get_registry_begin(libxsmm_kernel_kind kind, const void** key) +{ + void* result = NULL; + if (kind < LIBXSMM_KERNEL_UNREGISTERED && NULL != internal_registry) { + result = internal_get_registry_entry(0, kind, key); + } + return result; +} + + +LIBXSMM_API void* libxsmm_get_registry_next(const void* regentry, const void** key) +{ + void* result = NULL; + const libxsmm_descriptor* desc; + libxsmm_code_pointer entry; + entry.ptr_const = regentry; + if (NULL != libxsmm_get_kernel_xinfo(entry, &desc, NULL/*code_size*/) + /* given regentry is indeed a registered kernel */ + && NULL != desc) + { + result = internal_get_registry_entry( + (int)(desc - &internal_registry_keys->entry + 1), + (libxsmm_kernel_kind)LIBXSMM_DESCRIPTOR_KIND(desc->kind), key); + } + return result; +} + + +LIBXSMM_API void* libxsmm_xregister(const void* key, size_t key_size, + size_t value_size, const void* value_init, unsigned int* key_hash) +{ + static int error_once = 0; + void* result; + LIBXSMM_INIT /* verbosity */ + if (NULL != key && 0 < key_size && LIBXSMM_DESCRIPTOR_MAXSIZE >= key_size) { + libxsmm_descriptor wrap; + unsigned int hash = 0; + void* dst; +#if defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_MEMSET127(&wrap, 0, key_size); +#endif + LIBXSMM_MEMCPY127(wrap.user.desc, key, key_size); + wrap.kind = (libxsmm_descriptor_kind)(LIBXSMM_DESCRIPTOR_SIGSIZE >= key_size + ? ((libxsmm_descriptor_kind)LIBXSMM_KERNEL_KIND_USER) + : LIBXSMM_DESCRIPTOR_BIG(LIBXSMM_KERNEL_KIND_USER)); + dst = internal_find_code(&wrap, key_size, value_size, &hash).ptr; + if (NULL != key_hash) *key_hash = hash; + if (NULL != dst) { + size_t size; + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(dst, &size, NULL/*flags*/, NULL/*extra*/) + && value_size <= size) + { + if (NULL != value_init) memcpy(dst, value_init, value_size); + result = dst; + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: value too large for previously registered key!\n"); + } + result = NULL; + } + } + else result = NULL; + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (LIBXSMM_DESCRIPTOR_MAXSIZE >= key_size) { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xregister specified!\n"); + } + else { + fprintf(stderr, "LIBXSMM ERROR: libxsmm_xregister has maximum key-size of %i Byte!\n", + LIBXSMM_DESCRIPTOR_MAXSIZE); + } + } + result = NULL; + } + return result; +} + + +LIBXSMM_API void* libxsmm_xdispatch(const void* key, size_t key_size, unsigned int* key_hash) +{ + void* result; + LIBXSMM_INIT /* verbosity */ +#if !defined(NDEBUG) + if (NULL != key && 0 < key_size && LIBXSMM_DESCRIPTOR_MAXSIZE >= key_size) +#endif + { + unsigned int hash = 0; + libxsmm_descriptor wrap; +#if defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_MEMSET127(&wrap, 0, key_size); +#endif + LIBXSMM_MEMCPY127(wrap.user.desc, key, key_size); + wrap.kind = (libxsmm_descriptor_kind)(LIBXSMM_DESCRIPTOR_SIGSIZE >= key_size + ? ((libxsmm_descriptor_kind)LIBXSMM_KERNEL_KIND_USER) + : LIBXSMM_DESCRIPTOR_BIG(LIBXSMM_KERNEL_KIND_USER)); + result = internal_find_code(&wrap, key_size, 0/*user_size*/, &hash).ptr; + if (NULL != key_hash) *key_hash = hash; + } +#if !defined(NDEBUG) + else { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xdispatch specified!\n"); + } + result = NULL; + } +#endif + return result; +} + + +LIBXSMM_API void libxsmm_xrelease(const void* key, size_t key_size) +{ + libxsmm_release_kernel(libxsmm_xdispatch(key, key_size, NULL/*key_hash*/)); +} + + +LIBXSMM_API libxsmm_xmmfunction libxsmm_xmmdispatch(const libxsmm_gemm_descriptor* descriptor) +{ + libxsmm_xmmfunction result; + LIBXSMM_INIT /* verbosity */ +#if !defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_ASSERT((sizeof(*descriptor) + sizeof(libxsmm_descriptor_kind)) <= (LIBXSMM_DESCRIPTOR_MAXSIZE)); +#endif + if (NULL != descriptor) { + unsigned int hash; + const int batch_reduce = + LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | + LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | + LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE; + libxsmm_descriptor wrap; +#if defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_MEMSET127(&wrap, 0, sizeof(*descriptor)); +#endif + LIBXSMM_ASSIGN127(&wrap.gemm.desc, descriptor); + wrap.kind = (libxsmm_descriptor_kind)(0 == (batch_reduce & descriptor->flags) + ? ((libxsmm_descriptor_kind)LIBXSMM_KERNEL_KIND_MATMUL) + : LIBXSMM_DESCRIPTOR_BIG(LIBXSMM_KERNEL_KIND_MATMUL)); + if (0 != (0x80 & descriptor->prefetch)) { /* "sign"-bit of byte-value is set */ + wrap.gemm.desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + } + result = internal_find_code(&wrap, sizeof(*descriptor), 0/*user_size*/, &hash).xgemm; +#if defined(_DEBUG) + if (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity && INT_MAX != libxsmm_verbosity && NULL != result.xmm) { + LIBXSMM_STDIO_ACQUIRE(); + fprintf(stderr, "\nLIBXSMM: "); + libxsmm_gemm_xprint(stderr, result, NULL/*a*/, NULL/*b*/, NULL/*c*/); + LIBXSMM_STDIO_RELEASE(); + } +#endif + } + else { /* quietly accept NULL-descriptor */ + result.xmm = NULL; + } + return result; +} + + +LIBXSMM_API libxsmm_dmmfunction libxsmm_dmmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.dmm; +} + + +LIBXSMM_API libxsmm_smmfunction libxsmm_smmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.smm; +} + + +LIBXSMM_API libxsmm_bsmmfunction libxsmm_bsmmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.bsmm; +} + + +LIBXSMM_API libxsmm_bmmfunction libxsmm_bmmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.bmm; +} + + +LIBXSMM_API libxsmm_wimmfunction libxsmm_wimmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.wimm; +} + + +LIBXSMM_API libxsmm_ssbimmfunction libxsmm_ssbimmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.ssbimm; +} + + +LIBXSMM_API libxsmm_usbimmfunction libxsmm_usbimmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.usbimm; +} + + +LIBXSMM_API libxsmm_subimmfunction libxsmm_subimmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.subimm; +} + + +LIBXSMM_API libxsmm_uubimmfunction libxsmm_uubimmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.uubimm; +} + + +LIBXSMM_API libxsmm_sububmmfunction libxsmm_sububmmdispatch(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.sububmm; +} + + +LIBXSMM_API libxsmm_dmmfunction_reducebatch_addr libxsmm_dmmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.dmra; +} + + +LIBXSMM_API libxsmm_smmfunction_reducebatch_addr libxsmm_smmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.smra; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_addr libxsmm_bsmmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.bsmra; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_addr libxsmm_bmmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.bmra; +} + + +LIBXSMM_API libxsmm_wimmfunction_reducebatch_addr libxsmm_wimmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.wimra; +} + + +LIBXSMM_API libxsmm_ssbimmfunction_reducebatch_addr libxsmm_ssbimmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.ssbimra; +} + + +LIBXSMM_API libxsmm_usbimmfunction_reducebatch_addr libxsmm_usbimmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.usbimra; +} + + +LIBXSMM_API libxsmm_subimmfunction_reducebatch_addr libxsmm_subimmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.subimra; +} + + +LIBXSMM_API libxsmm_uubimmfunction_reducebatch_addr libxsmm_uubimmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.uubimra; +} + + +LIBXSMM_API libxsmm_sububmmfunction_reducebatch_addr libxsmm_sububmmdispatch_reducebatch_addr(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.sububmra; +} + + +LIBXSMM_API libxsmm_dmmfunction_reducebatch_addr libxsmm_dmmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.dmra; +} + + +LIBXSMM_API libxsmm_smmfunction_reducebatch_addr libxsmm_smmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.smra; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_addr libxsmm_bsmmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.bsmra; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_addr libxsmm_bmmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.bmra; +} + + +LIBXSMM_API libxsmm_wimmfunction_reducebatch_addr libxsmm_wimmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.wimra; +} + + +LIBXSMM_API libxsmm_ssbimmfunction_reducebatch_addr libxsmm_ssbimmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.ssbimra; +} + + +LIBXSMM_API libxsmm_usbimmfunction_reducebatch_addr libxsmm_usbimmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.usbimra; +} + + +LIBXSMM_API libxsmm_subimmfunction_reducebatch_addr libxsmm_subimmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.subimra; +} + + +LIBXSMM_API libxsmm_uubimmfunction_reducebatch_addr libxsmm_uubimmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.uubimra; +} + + +LIBXSMM_API libxsmm_sububmmfunction_reducebatch_addr libxsmm_sububmmdispatch_reducebatch_addr_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.sububmra; +} + + +LIBXSMM_API libxsmm_dmmfunction_reducebatch_offs libxsmm_dmmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.dmro; +} + + +LIBXSMM_API libxsmm_smmfunction_reducebatch_offs libxsmm_smmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.smro; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_offs libxsmm_bsmmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.bsmro; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_offs libxsmm_bmmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.bmro; +} + + +LIBXSMM_API libxsmm_wimmfunction_reducebatch_offs libxsmm_wimmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.wimro; +} + + +LIBXSMM_API libxsmm_ssbimmfunction_reducebatch_offs libxsmm_ssbimmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.ssbimro; +} + + +LIBXSMM_API libxsmm_usbimmfunction_reducebatch_offs libxsmm_usbimmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.usbimro; +} + + +LIBXSMM_API libxsmm_subimmfunction_reducebatch_offs libxsmm_subimmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.subimro; +} + + +LIBXSMM_API libxsmm_uubimmfunction_reducebatch_offs libxsmm_uubimmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.uubimro; +} + + +LIBXSMM_API libxsmm_sububmmfunction_reducebatch_offs libxsmm_sububmmdispatch_reducebatch_offs(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + const libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result = libxsmm_xmmdispatch(desc); + return result.sububmro; +} + + +LIBXSMM_API libxsmm_dmmfunction_reducebatch_offs libxsmm_dmmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.dmro; +} + + +LIBXSMM_API libxsmm_smmfunction_reducebatch_offs libxsmm_smmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.smro; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_offs libxsmm_bsmmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.bsmro; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_offs libxsmm_bmmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.bmro; +} + + +LIBXSMM_API libxsmm_wimmfunction_reducebatch_offs libxsmm_wimmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.wimro; +} + + +LIBXSMM_API libxsmm_ssbimmfunction_reducebatch_offs libxsmm_ssbimmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.ssbimro; +} + + +LIBXSMM_API libxsmm_usbimmfunction_reducebatch_offs libxsmm_usbimmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.usbimro; +} + + +LIBXSMM_API libxsmm_subimmfunction_reducebatch_offs libxsmm_subimmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.subimro; +} + + +LIBXSMM_API libxsmm_uubimmfunction_reducebatch_offs libxsmm_uubimmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.uubimro; +} + + +LIBXSMM_API libxsmm_sububmmfunction_reducebatch_offs libxsmm_sububmmdispatch_reducebatch_offs_unroll(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + result = libxsmm_xmmdispatch(desc); + return result.sububmro; +} + + +LIBXSMM_API libxsmm_dmmfunction_reducebatch_strd libxsmm_dmmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.dmrs; +} + + +LIBXSMM_API libxsmm_smmfunction_reducebatch_strd libxsmm_smmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.smrs; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_strd libxsmm_bsmmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.bsmrs; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_strd libxsmm_bmmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.bmrs; +} + + +LIBXSMM_API libxsmm_wimmfunction_reducebatch_strd libxsmm_wimmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.wimrs; +} + + +LIBXSMM_API libxsmm_ssbimmfunction_reducebatch_strd libxsmm_ssbimmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.ssbimrs; +} + + +LIBXSMM_API libxsmm_usbimmfunction_reducebatch_strd libxsmm_usbimmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.usbimrs; +} + + +LIBXSMM_API libxsmm_subimmfunction_reducebatch_strd libxsmm_subimmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.subimrs; +} + + +LIBXSMM_API libxsmm_uubimmfunction_reducebatch_strd libxsmm_uubimmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.uubimrs; +} + + +LIBXSMM_API libxsmm_sububmmfunction_reducebatch_strd libxsmm_sububmmdispatch_reducebatch_strd( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, + libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.sububmrs; +} + + +LIBXSMM_API libxsmm_dmmfunction_reducebatch_strd libxsmm_dmmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const double* alpha, const double* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_dgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.dmrs; +} + + +LIBXSMM_API libxsmm_smmfunction_reducebatch_strd libxsmm_smmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? LIBXSMM_FLAGS : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_sgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.smrs; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_strd libxsmm_bsmmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.bsmrs; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_strd libxsmm_bmmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const float* alpha, const float* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.bmrs; +} + + +LIBXSMM_API libxsmm_wimmfunction_reducebatch_strd libxsmm_wimmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_wigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.wimrs; +} + + +LIBXSMM_API libxsmm_ssbimmfunction_reducebatch_strd libxsmm_ssbimmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.ssbimrs; +} + + +LIBXSMM_API libxsmm_usbimmfunction_reducebatch_strd libxsmm_usbimmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.usbimrs; +} + + +LIBXSMM_API libxsmm_subimmfunction_reducebatch_strd libxsmm_subimmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.subimrs; +} + + +LIBXSMM_API libxsmm_uubimmfunction_reducebatch_strd libxsmm_uubimmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bigemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_AB_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.uubimrs; +} + + +LIBXSMM_API libxsmm_sububmmfunction_reducebatch_strd libxsmm_sububmmdispatch_reducebatch_strd_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const int* alpha, const int* beta, const int* flags, const int* prefetch) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bbgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_B_UNSIGNED | LIBXSMM_GEMM_FLAG_C_UNSIGNED | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, + libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + result = libxsmm_xmmdispatch(desc); + return result.sububmrs; +} + + +/* GEMMs fused with eltwise kernels */ +LIBXSMM_API libxsmm_bmmfunction_reducebatch_strd_meltwfused libxsmm_bmmdispatch_reducebatch_strd_meltwfused( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, const float* alpha, const float* beta, const int* flags, const int* prefetch, + libxsmm_meltw_operation meltw_op, libxsmm_datatype meltw_dt, libxsmm_meltw_flags meltw_flags, unsigned char meltw_param, unsigned int meltw_ldx, unsigned int meltw_ldy, unsigned int meltw_ldz) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + desc->meltw_datatype_aux = (unsigned char)meltw_dt; + desc->meltw_flags = (unsigned short)meltw_flags; + desc->meltw_operation = (unsigned char)meltw_op; + desc->meltw_param = (unsigned char)meltw_param; + desc->meltw_ldx = (unsigned int) meltw_ldx; + desc->meltw_ldy = (unsigned int) meltw_ldy; + desc->meltw_ldz = (unsigned int) meltw_ldz; + result = libxsmm_xmmdispatch(desc); + return result.bmrs_meltwfused; +} + + +LIBXSMM_API libxsmm_bmmfunction_reducebatch_strd_meltwfused libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, const float* alpha, const float* beta, const int* flags, const int* prefetch, + libxsmm_meltw_operation meltw_op, libxsmm_datatype meltw_dt, libxsmm_meltw_flags meltw_flags, unsigned char meltw_param, unsigned int meltw_ldx, unsigned int meltw_ldy, unsigned int meltw_ldz) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + desc->meltw_datatype_aux = (unsigned char)meltw_dt; + desc->meltw_flags = (unsigned short)meltw_flags; + desc->meltw_operation = (unsigned char)meltw_op; + desc->meltw_param = (unsigned char)meltw_param; + desc->meltw_ldx = (unsigned int) meltw_ldx; + desc->meltw_ldy = (unsigned int) meltw_ldy; + desc->meltw_ldz = (unsigned int) meltw_ldz; + result = libxsmm_xmmdispatch(desc); + return result.bmrs_meltwfused; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_strd_meltwfused libxsmm_bsmmdispatch_reducebatch_strd_meltwfused( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, const float* alpha, const float* beta, const int* flags, const int* prefetch, + libxsmm_meltw_operation meltw_op, libxsmm_datatype meltw_dt, libxsmm_meltw_flags meltw_flags, unsigned char meltw_param, unsigned int meltw_ldx, unsigned int meltw_ldy, unsigned int meltw_ldz) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + desc->meltw_datatype_aux = (unsigned char)meltw_dt; + desc->meltw_flags = (unsigned short)meltw_flags; + desc->meltw_operation = (unsigned char)meltw_op; + desc->meltw_param = (unsigned char)meltw_param; + desc->meltw_ldx = (unsigned int) meltw_ldx; + desc->meltw_ldy = (unsigned int) meltw_ldy; + desc->meltw_ldz = (unsigned int) meltw_ldz; + result = libxsmm_xmmdispatch(desc); + return result.bsmrs_meltwfused; +} + + +LIBXSMM_API libxsmm_bsmmfunction_reducebatch_strd_meltwfused libxsmm_bsmmdispatch_reducebatch_strd_meltwfused_unroll( + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint stride_a, libxsmm_blasint stride_b, libxsmm_blasint unroll_hint, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, const float* alpha, const float* beta, const int* flags, const int* prefetch, + libxsmm_meltw_operation meltw_op, libxsmm_datatype meltw_dt, libxsmm_meltw_flags meltw_flags, unsigned char meltw_param, unsigned int meltw_ldx, unsigned int meltw_ldy, unsigned int meltw_ldz) +{ + const int gemm_flags = (NULL == flags ? (LIBXSMM_FLAGS | LIBXSMM_GEMM_FLAG_VNNI_A) : *flags); + libxsmm_descriptor_blob blob; + /*const*/ libxsmm_gemm_descriptor *const desc = libxsmm_bsgemm_descriptor_init(&blob, m, n, k, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n), + NULL != ldc ? *ldc : m, NULL != alpha ? *alpha : LIBXSMM_ALPHA, NULL != beta ? *beta : LIBXSMM_BETA, + gemm_flags | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, libxsmm_get_gemm_xprefetch(prefetch)); + /*const*/ libxsmm_xmmfunction result; + desc->c1 = (unsigned long long)stride_a; + desc->c2 = (unsigned long long)stride_b; + desc->c3 = (unsigned char)(((unroll_hint < 255) && (unroll_hint > 0)) ? unroll_hint : 0); + if ( (stride_a < 0) || (stride_b < 0) ) { + return NULL; + } + desc->meltw_datatype_aux = (unsigned char)meltw_dt; + desc->meltw_flags = (unsigned short)meltw_flags; + desc->meltw_operation = (unsigned char)meltw_op; + desc->meltw_param = (unsigned char)meltw_param; + desc->meltw_ldx = (unsigned int) meltw_ldx; + desc->meltw_ldy = (unsigned int) meltw_ldy; + desc->meltw_ldz = (unsigned int) meltw_ldz; + result = libxsmm_xmmdispatch(desc); + return result.bsmrs_meltwfused; +} + + +LIBXSMM_API libxsmm_xmeltwfunction libxsmm_dispatch_meltw(const libxsmm_meltw_descriptor* descriptor) +{ + libxsmm_xmeltwfunction result; + LIBXSMM_INIT /* verbosity */ +#if !defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_ASSERT((sizeof(*descriptor) + sizeof(libxsmm_descriptor_kind)) <= (LIBXSMM_DESCRIPTOR_MAXSIZE)); +#endif + if (NULL != descriptor) { + unsigned int hash; + libxsmm_descriptor wrap; +#if defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_MEMSET127(&wrap, 0, sizeof(*descriptor)); +#endif + LIBXSMM_ASSIGN127(&wrap.meltw.desc, descriptor); + wrap.kind = LIBXSMM_DESCRIPTOR_BIG(LIBXSMM_KERNEL_KIND_MELTW); + result = internal_find_code(&wrap, sizeof(*descriptor), 0/*user_size*/, &hash).xmateltw; + } + else { + result.xmeltw = NULL; + } + return result; +} + + +LIBXSMM_API libxsmm_meltwfunction_reduce_cols_idx libxsmm_dispatch_meltw_reduce_cols_idx( + libxsmm_blasint m, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo, + libxsmm_datatype in_type, libxsmm_datatype out_type, libxsmm_datatype idx_type) +{ + libxsmm_descriptor_blob blob; + libxsmm_blasint idx_dtype_size = libxsmm_typesize(idx_type); + const libxsmm_meltw_descriptor *const desc = libxsmm_meltw_descriptor_init(&blob, + in_type, out_type, m, idx_dtype_size, (ldi == NULL) ? m : *ldi, (ldo == NULL) ? m : *ldo, + 0, 0, LIBXSMM_MELTW_OPERATION_REDUCE_COLS_IDX); + + libxsmm_xmeltwfunction result = libxsmm_dispatch_meltw(desc); + + return result.meltw_reduce_cols_idx; +} + + +LIBXSMM_API libxsmm_meltwfunction_opreduce_vecs_idx libxsmm_dispatch_meltw_opreduce_vecs_idx( + libxsmm_blasint m, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo, + libxsmm_datatype in_type, libxsmm_datatype out_type, libxsmm_datatype idx_type, libxsmm_meltw_opreduce_vecs_flags flags) +{ + libxsmm_descriptor_blob blob; + libxsmm_blasint idx_dtype_size = libxsmm_typesize(idx_type); + unsigned char argidx_params = (unsigned char) (((flags & LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0) | (flags & LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1)) >> 16); + const libxsmm_meltw_descriptor *const desc = libxsmm_meltw_descriptor_init(&blob, + in_type, out_type, m, idx_dtype_size, (ldi == NULL) ? m : *ldi, (ldo == NULL) ? m : *ldo, + (unsigned short)flags, argidx_params, LIBXSMM_MELTW_OPERATION_OPREDUCE_VECS_IDX); + + libxsmm_xmeltwfunction result = libxsmm_dispatch_meltw(desc); + + return result.meltw_opreduce_vecs_idx; +} + + +LIBXSMM_API libxsmm_meltwfunction_unary libxsmm_dispatch_meltw_unary( + libxsmm_blasint m, libxsmm_blasint n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo, + libxsmm_datatype in_type, libxsmm_datatype compute_type, libxsmm_datatype out_type, libxsmm_meltw_unary_flags flags, libxsmm_meltw_unary_type type) +{ + libxsmm_descriptor_blob blob; + const libxsmm_meltw_descriptor *const desc = libxsmm_meltw_descriptor_init2(&blob, + in_type, compute_type, out_type, LIBXSMM_DATATYPE_UNSUPPORTED, m, n, (ldi == NULL) ? m : *ldi, (ldo == NULL) ? m : *ldo, 0, 0, + (unsigned short)flags, (unsigned char)type, LIBXSMM_MELTW_OPERATION_UNARY); + + libxsmm_xmeltwfunction result = libxsmm_dispatch_meltw(desc); + + return result.meltw_unary; +} + + +LIBXSMM_API libxsmm_meltwfunction_binary libxsmm_dispatch_meltw_binary( + libxsmm_blasint m, libxsmm_blasint n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldi2, const libxsmm_blasint* ldo, + libxsmm_datatype in_type, libxsmm_datatype compute_type, libxsmm_datatype out_type, libxsmm_meltw_binary_flags flags, libxsmm_meltw_binary_type type) +{ + libxsmm_descriptor_blob blob; + const libxsmm_meltw_descriptor *const desc = libxsmm_meltw_descriptor_init2(&blob, + in_type, compute_type, out_type, LIBXSMM_DATATYPE_UNSUPPORTED, m, n, (ldi == NULL) ? m : *ldi, (ldo == NULL) ? m : *ldo, (ldi2 == NULL) ? m : *ldi2, 0, + (unsigned short)flags, (unsigned char)type, LIBXSMM_MELTW_OPERATION_BINARY); + + libxsmm_xmeltwfunction result = libxsmm_dispatch_meltw(desc); + + return result.meltw_binary; +} + + +LIBXSMM_API libxsmm_meltwfunction_ternary libxsmm_dispatch_meltw_ternary( + libxsmm_blasint m, libxsmm_blasint n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldi2, const libxsmm_blasint* ldi3, const libxsmm_blasint* ldo, + libxsmm_datatype in_type, libxsmm_datatype compute_type, libxsmm_datatype out_type, libxsmm_meltw_ternary_flags flags, libxsmm_meltw_ternary_type type) +{ + libxsmm_descriptor_blob blob; + const libxsmm_meltw_descriptor *const desc = libxsmm_meltw_descriptor_init2(&blob, + in_type, compute_type, out_type, LIBXSMM_DATATYPE_UNSUPPORTED, m, n, (ldi == NULL) ? m : *ldi, (ldo == NULL) ? m : *ldo, (ldi2 == NULL) ? m : *ldi2, (ldi3 == NULL) ? m : *ldi3, + (unsigned short)flags, (unsigned char)type, LIBXSMM_MELTW_OPERATION_TERNARY); + + libxsmm_xmeltwfunction result = libxsmm_dispatch_meltw(desc); + + return result.meltw_ternary; +} + + +LIBXSMM_API libxsmm_matrix_eqn_function libxsmm_dispatch_matrix_eqn_desc( const libxsmm_meqn_descriptor* descriptor ) { + libxsmm_matrix_eqn_function result; + LIBXSMM_INIT /* verbosity */ +#if !defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_ASSERT((sizeof(*descriptor) + sizeof(libxsmm_descriptor_kind)) <= (LIBXSMM_DESCRIPTOR_MAXSIZE)); +#endif + if (NULL != descriptor) { + unsigned int hash; + libxsmm_descriptor wrap; + + /* check if equation is ready for JIT */ + if ( libxsmm_matrix_eqn_is_ready_for_jit( descriptor->eqn_idx) == 0 ) { +#if defined(LIBXSMM_UNPACKED) /* CCE/Classic */ + LIBXSMM_MEMSET127(&wrap, 0, sizeof(*descriptor)); +#endif + LIBXSMM_ASSIGN127(&wrap.meqn.desc, descriptor); + wrap.kind = LIBXSMM_DESCRIPTOR_BIG(LIBXSMM_KERNEL_KIND_MEQN); + result = internal_find_code(&wrap, sizeof(*descriptor), 0/*user_size*/, &hash).xmateqn; + } else { + result = NULL; + } + } + else { + result = NULL; + } + return result; +} + + +LIBXSMM_API libxsmm_matrix_eqn_function libxsmm_dispatch_matrix_eqn( + const libxsmm_blasint m, const libxsmm_blasint n, const libxsmm_blasint* ldo, + const libxsmm_datatype out_type, const unsigned int eqn_idx ) +{ + libxsmm_descriptor_blob blob; + const libxsmm_meqn_descriptor *const desc = libxsmm_meqn_descriptor_init(&blob, + out_type, m, n, (ldo == NULL) ? m : *ldo, eqn_idx ); + + return libxsmm_dispatch_matrix_eqn_desc( desc ); +} + + +LIBXSMM_API libxsmm_xmmfunction libxsmm_create_packed_spxgemm_csr(const libxsmm_gemm_descriptor* descriptor, unsigned int packed_width, + const unsigned int* row_ptr, const unsigned int* column_idx, const void* values) +{ + libxsmm_code_pointer result = { 0 }; + LIBXSMM_INIT + if (NULL != descriptor && NULL != row_ptr && NULL != column_idx && NULL != values) { + libxsmm_pspgemm_csr_descriptor pspgemm_csr; + libxsmm_build_request request; + libxsmm_gemm_descriptor desc; + if (0 == (0x80 & descriptor->prefetch)) { + pspgemm_csr.gemm = descriptor; + } + else { /* "sign"-bit of byte-value is set */ + LIBXSMM_ASSIGN127(&desc, descriptor); + desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + pspgemm_csr.gemm = &desc; + } + pspgemm_csr.row_ptr = row_ptr; + pspgemm_csr.column_idx = column_idx; + pspgemm_csr.values = values; + pspgemm_csr.packed_width = packed_width; + request.descriptor.pspgemm_csr = &pspgemm_csr; + request.kind = LIBXSMM_BUILD_KIND_PSPGEMM_CSR; + libxsmm_build(&request, LIBXSMM_CAPACITY_REGISTRY/*not managed*/, &result); + } + return result.xgemm; +} + + +LIBXSMM_API libxsmm_xmmfunction libxsmm_create_packed_spxgemm_csc(const libxsmm_gemm_descriptor* descriptor, unsigned int packed_width, + const unsigned int* column_ptr, const unsigned int* row_idx, const void* values) +{ + libxsmm_code_pointer result = { 0 }; + LIBXSMM_INIT + if (NULL != descriptor && NULL != column_ptr && NULL != row_idx && NULL != values) { + libxsmm_pspgemm_csc_descriptor pspgemm_csc; + libxsmm_build_request request; + libxsmm_gemm_descriptor desc; + if (0 == (0x80 & descriptor->prefetch)) { + pspgemm_csc.gemm = descriptor; + } + else { /* "sign"-bit of byte-value is set */ + LIBXSMM_ASSIGN127(&desc, descriptor); + desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + pspgemm_csc.gemm = &desc; + } + pspgemm_csc.column_ptr = column_ptr; + pspgemm_csc.row_idx = row_idx; + pspgemm_csc.values = values; + pspgemm_csc.packed_width = packed_width; + request.descriptor.pspgemm_csc = &pspgemm_csc; + request.kind = LIBXSMM_BUILD_KIND_PSPGEMM_CSC; + libxsmm_build(&request, LIBXSMM_CAPACITY_REGISTRY/*not managed*/, &result); + } + return result.xgemm; +} + + +LIBXSMM_API libxsmm_xmmfunction libxsmm_create_packed_xgemm_ac_rm(const libxsmm_gemm_descriptor* descriptor, unsigned int packed_width) +{ + libxsmm_code_pointer result = { 0 }; + LIBXSMM_INIT + if (NULL != descriptor) { + libxsmm_pgemm_ac_rm_descriptor pgemmacrm; + libxsmm_build_request request; + libxsmm_gemm_descriptor desc; + if (0 == (0x80 & descriptor->prefetch)) { + pgemmacrm.gemm = descriptor; + } + else { /* "sign"-bit of byte-value is set */ + LIBXSMM_ASSIGN127(&desc, descriptor); + desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + pgemmacrm.gemm = &desc; + } + pgemmacrm.packed_width = packed_width; + request.descriptor.pgemmacrm = &pgemmacrm; + request.kind = LIBXSMM_BUILD_KIND_PGEMMRMAC; + libxsmm_build(&request, LIBXSMM_CAPACITY_REGISTRY/*not managed*/, &result); + } + return result.xgemm; +} + + +LIBXSMM_API libxsmm_xmmfunction libxsmm_create_packed_xgemm_bc_rm(const libxsmm_gemm_descriptor* descriptor, unsigned int packed_width) +{ + libxsmm_code_pointer result = { 0 }; + LIBXSMM_INIT + if (NULL != descriptor) { + libxsmm_pgemm_bc_rm_descriptor pgemmbcrm; + libxsmm_build_request request; + libxsmm_gemm_descriptor desc; + if (0 == (0x80 & descriptor->prefetch)) { + pgemmbcrm.gemm = descriptor; + } + else { /* "sign"-bit of byte-value is set */ + LIBXSMM_ASSIGN127(&desc, descriptor); + desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + pgemmbcrm.gemm = &desc; + } + pgemmbcrm.packed_width = packed_width; + request.descriptor.pgemmbcrm = &pgemmbcrm; + request.kind = LIBXSMM_BUILD_KIND_PGEMMRMBC; + libxsmm_build(&request, LIBXSMM_CAPACITY_REGISTRY/*not managed*/, &result); + } + return result.xgemm; +} + + +LIBXSMM_API libxsmm_dmmfunction libxsmm_create_dcsr_reg(const libxsmm_gemm_descriptor* descriptor, + const unsigned int* row_ptr, const unsigned int* column_idx, const double* values) +{ + libxsmm_code_pointer result = { 0 }; + LIBXSMM_INIT + if (NULL != descriptor && NULL != row_ptr && NULL != column_idx && NULL != values) { + libxsmm_csr_reg_descriptor sreg; + libxsmm_build_request request; + libxsmm_gemm_descriptor desc; + if (0 == (0x80 & descriptor->prefetch)) { + sreg.gemm = descriptor; + } + else { /* "sign"-bit of byte-value is set */ + LIBXSMM_ASSIGN127(&desc, descriptor); + desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + sreg.gemm = &desc; + } + sreg.row_ptr = row_ptr; + sreg.column_idx = column_idx; + sreg.values = values; + request.descriptor.sreg = &sreg; + request.kind = LIBXSMM_BUILD_KIND_SREG; + libxsmm_build(&request, LIBXSMM_CAPACITY_REGISTRY/*not managed*/, &result); + } + return result.xgemm.dmm; +} + + +LIBXSMM_API libxsmm_smmfunction libxsmm_create_scsr_reg(const libxsmm_gemm_descriptor* descriptor, + const unsigned int* row_ptr, const unsigned int* column_idx, const float* values) +{ + libxsmm_code_pointer result = { 0 }; + LIBXSMM_INIT + if (NULL != descriptor && NULL != row_ptr && NULL != column_idx && NULL != values) { + libxsmm_csr_reg_descriptor sreg; + libxsmm_build_request request; + const unsigned int n = row_ptr[descriptor->m]; + double *const d_values = (double*)(0 != n ? malloc(n * sizeof(double)) : NULL); + if (NULL != d_values) { + libxsmm_gemm_descriptor desc; + unsigned int i; + /* we need to copy the values into a double precision buffer */ + for (i = 0; i < n; ++i) d_values[i] = (double)values[i]; + if (0 == (0x80 & descriptor->prefetch)) { + sreg.gemm = descriptor; + } + else { /* "sign"-bit of byte-value is set */ + LIBXSMM_ASSIGN127(&desc, descriptor); + desc.prefetch = (unsigned char)libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO); + sreg.gemm = &desc; + } + sreg.row_ptr = row_ptr; + sreg.column_idx = column_idx; + sreg.values = d_values; + request.descriptor.sreg = &sreg; + request.kind = LIBXSMM_BUILD_KIND_SREG; + libxsmm_build(&request, LIBXSMM_CAPACITY_REGISTRY/*not managed*/, &result); + free(d_values); + } + } + return result.xgemm.smm; +} + + +LIBXSMM_API void libxsmm_release_kernel(const void* kernel) +{ + if (NULL != kernel) { + static int error_once = 0; + libxsmm_kernel_xinfo* extra = NULL; + void *const extra_address = &extra; + LIBXSMM_INIT + if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo( + kernel, NULL/*size*/, NULL/*flags*/, (void**)extra_address) && NULL != extra) + { + const unsigned int regindex = extra->registered; + if ((LIBXSMM_CAPACITY_REGISTRY) <= regindex) { + libxsmm_xfree(kernel, 0/*no check*/); + } + else { /* attempt to unregister kernel */ + libxsmm_kernel_info info; +#if !defined(LIBXSMM_ENABLE_DEREG) + if (EXIT_SUCCESS == libxsmm_get_kernel_info(kernel, &info) + && LIBXSMM_KERNEL_KIND_USER == info.kind) +#endif + { + LIBXSMM_ASSERT(LIBXSMM_KERNEL_UNREGISTERED > info.kind); + /* coverity[check_return] */ + LIBXSMM_ATOMIC_ADD_FETCH(&libxsmm_ninit, 1, LIBXSMM_ATOMIC_RELAXED); /* invalidate code cache (TLS) */ + internal_registry[regindex].ptr = NULL; +#if !defined(NDEBUG) + memset(internal_registry_keys + regindex, 0, sizeof(*internal_registry_keys)); +#endif + libxsmm_xfree(kernel, 0/*no check*/); + } +#if !defined(LIBXSMM_ENABLE_DEREG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: attempt to unregister JIT-kernel!\n"); + } +#endif + } + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to release kernel!\n"); + } + } +} + + +#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_init)(void); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_init)(void) +{ + libxsmm_init(); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_finalize)(void); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_finalize)(void) +{ + libxsmm_finalize(); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_release_kernel)(const void** /*kernel*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_release_kernel)(const void** kernel) +{ +#if !defined(NDEBUG) + if (NULL != kernel) +#endif + { + libxsmm_release_kernel(*kernel); + } +#if !defined(NDEBUG) + else { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument passed into libxsmm_release_kernel!\n"); + } + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmdispatch2)(intptr_t* /*fn*/, const int* /*iprec*/, const int* /*oprec*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*k*/, + const libxsmm_blasint* /*lda*/, const libxsmm_blasint* /*ldb*/, const libxsmm_blasint* /*ldc*/, + const void* /*alpha*/, const void* /*beta*/, const int* /*flags*/, const int* /*prefetch*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmdispatch2)(intptr_t* fn, const int* iprec, const int* oprec, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const void* alpha, const void* beta, const int* flags, const int* prefetch) +{ +#if !defined(NDEBUG) + if (NULL != fn && NULL != m + && (NULL == iprec || (0 <= *iprec && *iprec < LIBXSMM_DATATYPE_UNSUPPORTED)) + && (NULL == oprec || (0 <= *oprec && *oprec < LIBXSMM_DATATYPE_UNSUPPORTED))) +#endif + { + const int gemm_flags = (NULL != flags ? *flags : LIBXSMM_FLAGS); + const libxsmm_gemm_descriptor* descriptor; + libxsmm_gemm_prefetch_type gemm_prefetch; + libxsmm_descriptor_blob blob; + libxsmm_code_pointer result; +#if !defined(NDEBUG) + const libxsmm_gemm_precision itype = (NULL != iprec ? ((libxsmm_gemm_precision)*iprec) : LIBXSMM_GEMM_PRECISION_F64); + const libxsmm_gemm_precision otype = (NULL != oprec ? ((libxsmm_gemm_precision)*oprec) : itype); + const libxsmm_blasint kk = *(NULL != k ? k : m), nn = (NULL != n ? *n : kk); +#else + const libxsmm_gemm_precision itype = (libxsmm_gemm_precision)*iprec, otype = (libxsmm_gemm_precision)*oprec; + const libxsmm_blasint kk = *k, nn = *n; +#endif + LIBXSMM_PRAGMA_FORCEINLINE + gemm_prefetch = libxsmm_get_gemm_xprefetch(prefetch); + LIBXSMM_PRAGMA_FORCEINLINE + descriptor = libxsmm_gemm_descriptor_init2(&blob, itype, otype, *m, nn, kk, + NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? *m : kk), + NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? kk : nn), + *(NULL != ldc ? ldc : m), alpha, beta, gemm_flags, gemm_prefetch); +#if !defined(NDEBUG) + if (NULL != descriptor) +#endif + { + LIBXSMM_PRAGMA_FORCEINLINE + result.xgemm = libxsmm_xmmdispatch(descriptor); + *fn = result.ival; + } +#if !defined(NDEBUG) + else { /* quiet */ + *fn = 0; + } +#endif + } +#if !defined(NDEBUG) + else { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument passed into libxsmm_xmmdispatch!\n"); + } + if (NULL != fn) *fn = 0; + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmdispatch)(intptr_t* /*fn*/, const int* /*precision*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*k*/, + const libxsmm_blasint* /*lda*/, const libxsmm_blasint* /*ldb*/, const libxsmm_blasint* /*ldc*/, + const void* /*alpha*/, const void* /*beta*/, const int* /*flags*/, const int* /*prefetch*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmdispatch)(intptr_t* fn, const int* precision, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, + const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc, + const void* alpha, const void* beta, const int* flags, const int* prefetch) +{ + LIBXSMM_FSYMBOL(libxsmm_xmmdispatch2)(fn, precision, precision, m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmcall_abc)( + const libxsmm_xmmfunction* /*fn*/, const void* /*a*/, const void* /*b*/, void* /*c*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmcall_abc)( + const libxsmm_xmmfunction* fn, const void* a, const void* b, void* c) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != fn && NULL != a && NULL != b && NULL != c) +#endif + { +#if !defined(NDEBUG) + if (NULL != fn->xmm) +#endif + { + fn->xmm(a, b, c); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: NULL-function passed into libxsmm_xmmcall_abc!\n"); + } +#endif + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xmmcall_abc specified!\n"); + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmcall_prf)( + const libxsmm_xmmfunction* /*fn*/, const void* /*a*/, const void* /*b*/, void* /*c*/, + const void* /*pa*/, const void* /*pb*/, const void* /*pc*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmcall_prf)( + const libxsmm_xmmfunction* fn, const void* a, const void* b, void* c, + const void* pa, const void* pb, const void* pc) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != fn && NULL != a && NULL != b && NULL != c) +#endif + { +#if !defined(NDEBUG) + if (NULL != fn->xmm) +#endif + { + fn->xmm(a, b, c, pa, pb, pc); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: NULL-function passed into libxsmm_xmmcall_prf!\n"); + } +#endif + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xmmcall_prf specified!\n"); + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmcall)( + const libxsmm_xmmfunction* /*fn*/, const void* /*a*/, const void* /*b*/, void* /*c*/, + const void* /*pa*/, const void* /*pb*/, const void* /*pc*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xmmcall)( + const libxsmm_xmmfunction* fn, const void* a, const void* b, void* c, + const void* pa, const void* pb, const void* pc) +{ + LIBXSMM_FSYMBOL(libxsmm_xmmcall_prf)(fn, a, b, c, pa, pb, pc); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xregister)(void** /*regval*/, const void* /*key*/, const int* /*keysize*/, + const int* /*valsize*/, const void* /*valinit*/, int* /*keyhash*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xregister)(void** regval, const void* key, const int* keysize, + const int* valsize, const void* valinit, int* keyhash) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != regval && NULL != key && NULL != keysize && NULL != valsize) +#endif + { + unsigned int hash = 0; + *regval = libxsmm_xregister(key, *keysize, *valsize, valinit, &hash); + if (NULL != keyhash) { + *keyhash = (hash & 0x7FFFFFFF/*sign-bit*/); + } + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xregister specified!\n"); + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xdispatch)(void** /*regval*/, const void* /*key*/, const int* /*keysize*/, int* /*keyhash*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xdispatch)(void** regval, const void* key, const int* keysize, int* keyhash) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != regval && NULL != key && NULL != keysize) +#endif + { + unsigned int hash = 0; + *regval = libxsmm_xdispatch(key, *keysize, &hash); + if (NULL != keyhash) { + *keyhash = (hash & 0x7FFFFFFF/*sign-bit*/); + } + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xdispatch specified!\n"); + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xrelease)(const void* /*key*/, const int* /*keysize*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xrelease)(const void* key, const int* keysize) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != key && NULL != keysize) +#endif + { + libxsmm_xrelease(key, *keysize); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xrelease specified!\n"); + } +#endif +} + +#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_main.h b/third_party/libxsmm/src/libxsmm_main.h new file mode 100644 index 0000000000000000000000000000000000000000..d33cc5dbfe1e2fe0671c0e43194a479c86036942 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_main.h @@ -0,0 +1,1069 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MAIN_H +#define LIBXSMM_MAIN_H + +#include +/** + * TF includes src/libxsmm_main.h and uses LIBXSMM's sync primitives + * without including libxsmm_sync. However, libxsmm_sync.h shall be + * an explicit include separate from including libxsmm.h. + */ +#include "libxsmm_sync.h" + +/** Allow external definition to enable testing corner cases (exhausted registry space). */ +#if !defined(LIBXSMM_CAPACITY_REGISTRY) /* must be POT */ +# define LIBXSMM_CAPACITY_REGISTRY 131072 +#endif +#if !defined(LIBXSMM_CAPACITY_CACHE) /* must be POT */ +# define LIBXSMM_CAPACITY_CACHE 16 +#endif + +#if !defined(LIBXSMM_PAGE_MINSIZE) +# define LIBXSMM_PAGE_MINSIZE 4096 /* 4 KB */ +#endif + +#if !defined(LIBXSMM_BATCH_CHECK) && !defined(NDEBUG) +# define LIBXSMM_BATCH_CHECK +#endif + +#if !defined(LIBXSMM_NTHREADS_MAX) +# if (0 != LIBXSMM_SYNC) +# define LIBXSMM_NTHREADS_MAX 1024 +# else +# define LIBXSMM_NTHREADS_MAX 1 +# endif +#endif +/* relies on LIBXSMM_NTHREADS_MAX */ +#if !defined(LIBXSMM_NTHREADS_USE) && 0 +# define LIBXSMM_NTHREADS_USE +#endif +#if !defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) +# define LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS LIBXSMM_NTHREADS_MAX +#endif +#if !defined(LIBXSMM_MALLOC_SCRATCH_SCALE) +# define LIBXSMM_MALLOC_SCRATCH_SCALE 1.0 +#endif +#if !defined(LIBXSMM_MALLOC_LIMIT) +# define LIBXSMM_MALLOC_LIMIT (2U << 20) /* 2 MB */ +#endif +/* map memory also for non-executable buffers */ +#if !defined(LIBXSMM_MALLOC_MMAP) && 0 +# define LIBXSMM_MALLOC_MMAP +#endif +/* map memory for hooked allocation */ +#if !defined(LIBXSMM_MALLOC_MMAP_HOOK) && 1 +# define LIBXSMM_MALLOC_MMAP_HOOK +#endif +/* map memory for scratch buffers */ +#if !defined(LIBXSMM_MALLOC_MMAP_SCRATCH) && 1 +# define LIBXSMM_MALLOC_MMAP_SCRATCH +#endif +/* align even if interceptor is disabled at runtime */ +#if !defined(LIBXSMM_MALLOC_ALIGN_ALL) && 1 +# define LIBXSMM_MALLOC_ALIGN_ALL +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_INTRINSIC) && 1 +# if defined(LIBXSMM_PLATFORM_X86) && defined(LIBXSMM_INTRINSICS_INCLUDE) && \ + !defined(LIBXSMM_INTRINSICS_DEBUG) && !defined(LIBXSMM_MALLOC_MMAP) +# define LIBXSMM_MALLOC_HOOK_INTRINSIC +# endif +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_REALLOC) && 1 +# if !defined(LIBXSMM_MALLOC_HOOK_INTRINSIC) +# define LIBXSMM_MALLOC_HOOK_REALLOC +# endif +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_CALLOC) && 1 +# define LIBXSMM_MALLOC_HOOK_CALLOC +#endif +#if !defined(LIBXSMM_MALLOC_INTERNAL_CALLER_ID) +# define LIBXSMM_MALLOC_INTERNAL_CALLER_ID ((uintptr_t)LIBXSMM_UNLIMITED) +#endif +#if !defined(LIBXSMM_MALLOC_INTERNAL_CALLER) +# define LIBXSMM_MALLOC_INTERNAL_CALLER ((const void*)(LIBXSMM_MALLOC_INTERNAL_CALLER_ID)) +#endif + +#if !defined(LIBXSMM_INTERCEPT_DYNAMIC) && defined(LIBXSMM_BUILD) && \ + (defined(__GNUC__) || defined(_CRAYC)) && !defined(_WIN32) && !defined(__CYGWIN__) && \ + !(defined(__APPLE__) && defined(__MACH__) && LIBXSMM_VERSION2(6, 1) >= \ + LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) +# define LIBXSMM_INTERCEPT_DYNAMIC +#endif + +#if !defined(LIBXSMM_MALLOC_HOOK_STATIC) && \ + (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) /* GLIBC */ && \ + (!defined(_WIN32)) /* TODO */ +# define LIBXSMM_MALLOC_HOOK_STATIC +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) && defined(LIBXSMM_INTERCEPT_DYNAMIC) && \ + defined(LIBXSMM_MALLOC_HOOK_STATIC) && !defined(_CRAYC) && !defined(__TRACE) +# define LIBXSMM_MALLOC_HOOK_DYNAMIC +#endif +#if (defined(LIBXSMM_MALLOC_HOOK_STATIC) || defined(LIBXSMM_MALLOC_HOOK_DYNAMIC)) +# define LIBXSMM_MALLOC_HOOK +#endif +#if !defined(LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS) && defined(LIBXSMM_MALLOC_HOOK) && \ + (defined(LIBXSMM_MALLOC_ALIGN_ALL) || (defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC))) +# define LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS +#endif + +#if defined(LIBXSMM_INTERCEPT_DYNAMIC) +# if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +# endif +# include +# if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +# endif +# if !defined(RTLD_NEXT) +# define LIBXSMM_RTLD_NEXT ((void*)-1l) +# else +# define LIBXSMM_RTLD_NEXT RTLD_NEXT +# endif +#endif + +#if !defined(LIBXSMM_VERBOSITY_HIGH) +# define LIBXSMM_VERBOSITY_HIGH 3 /* secondary warning or info-verbosity */ +#endif +#if !defined(LIBXSMM_VERBOSITY_WARN) +# define LIBXSMM_VERBOSITY_WARN ((LIBXSMM_VERBOSITY_HIGH) - LIBXSMM_MIN(1, LIBXSMM_VERBOSITY_HIGH)) +#endif + +#if !defined(LIBXSMM_LOCK) +# define LIBXSMM_LOCK LIBXSMM_LOCK_DEFAULT +#endif + +#if !defined(LIBXSMM_EXT_MIN_NTASKS) +# define LIBXSMM_MIN_NTASKS(NT) 1 +#endif +#if !defined(LIBXSMM_OVERHEAD) +# define LIBXSMM_OVERHEAD(NT) 0 +#endif +#if !defined(LIBXSMM_NOOP_ARGS) +# define LIBXSMM_NOOP_ARGS(...) +#endif +#if !defined(LIBXSMM_NOOP) +# define LIBXSMM_NOOP +#endif + +/** Check if M, N, K, or LDx fits into the descriptor. */ +#if (0 != LIBXSMM_ILP64) +# define LIBXSMM_GEMM_NO_BYPASS_DIMS(M, N, K) (0xFFFFFFFF >= (M) && 0xFFFFFFFF >= (N) && 0xFFFFFFFF >= (K)) +#else /* always fits */ +# define LIBXSMM_GEMM_NO_BYPASS_DIMS(M, N, K) 1 +#endif + +#if defined(LIBXSMM_ASSERT) /* assert available */ +# define LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(M, N, K) LIBXSMM_ASSERT(LIBXSMM_GEMM_NO_BYPASS_DIMS(M, N, K)) +#else +# define LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(M, N, K) +#endif + +#if defined(LIBXSMM_UNPACKED) +# define LIBXSMM_DESCRIPTOR_CLEAR_AUX(DST, SIZE) LIBXSMM_MEMSET127(DST, 0, SIZE) +#else +# define LIBXSMM_DESCRIPTOR_CLEAR_AUX(DST, SIZE) +#endif +#define LIBXSMM_DESCRIPTOR_CLEAR(BLOB) \ + LIBXSMM_ASSERT((LIBXSMM_DESCRIPTOR_MAXSIZE) == sizeof(*(BLOB))); \ + LIBXSMM_DESCRIPTOR_CLEAR_AUX(BLOB, LIBXSMM_DESCRIPTOR_MAXSIZE) + +/** Low-level/internal GEMM descriptor initialization. */ +#define LIBXSMM_GEMM_DESCRIPTOR(DESCRIPTOR, DATA_TYPE, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \ + LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(LDA, LDB, LDC); \ + LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(M, N, K); \ + LIBXSMM_DESCRIPTOR_CLEAR_AUX(&(DESCRIPTOR), sizeof(DESCRIPTOR)); \ + (DESCRIPTOR).datatype = (unsigned char)(DATA_TYPE); (DESCRIPTOR).prefetch = (unsigned char)(PREFETCH); \ + (DESCRIPTOR).flags = (unsigned int)((FLAGS) \ + /*| (LIBXSMM_NEQ(0, ALPHA) ? 0 : LIBXSMM_GEMM_FLAG_ALPHA_0)*/ \ + | (LIBXSMM_NEQ(0, BETA) ? 0 : LIBXSMM_GEMM_FLAG_BETA_0)); \ + (DESCRIPTOR).m = (unsigned int)(M); (DESCRIPTOR).n = (unsigned int)(N); (DESCRIPTOR).k = (unsigned int)(K); \ + (DESCRIPTOR).lda = (unsigned int)(LDA); (DESCRIPTOR).ldb = (unsigned int)(LDB); (DESCRIPTOR).ldc = (unsigned int)(LDC); \ + (DESCRIPTOR).meltw_datatype_aux = 0; (DESCRIPTOR).c1 = 0; (DESCRIPTOR).c2 = 0; (DESCRIPTOR).c3 = 0; \ + (DESCRIPTOR).meltw_ldx = 0; (DESCRIPTOR).meltw_ldy = 0; (DESCRIPTOR).meltw_ldz = 0; \ + (DESCRIPTOR).meltw_param = 0; (DESCRIPTOR).meltw_flags = 0; \ + (DESCRIPTOR).meltw_operation = 0 + +/** Similar to LIBXSMM_GEMM_DESCRIPTOR, but separately taking the input-/output-precision. */ +#define LIBXSMM_GEMM_DESCRIPTOR2(DESCRIPTOR, IPREC, OPREC, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \ + LIBXSMM_GEMM_DESCRIPTOR(DESCRIPTOR, LIBXSMM_GETENUM(IPREC, OPREC), FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) + +/** Declare and construct a GEMM descriptor. */ +#define LIBXSMM_GEMM_DESCRIPTOR_TYPE(DESCRIPTOR, DATA_TYPE, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \ + libxsmm_gemm_descriptor DESCRIPTOR; LIBXSMM_GEMM_DESCRIPTOR(DESCRIPTOR, DATA_TYPE, \ + FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) + +/** Similar to LIBXSMM_GEMM_DESCRIPTOR_TYPE, but separately taking the input-/output-precision. */ +#define LIBXSMM_GEMM_DESCRIPTOR2_TYPE(DESCRIPTOR, IPREC, OPREC, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \ + LIBXSMM_GEMM_DESCRIPTOR_TYPE(DESCRIPTOR, LIBXSMM_GETENUM(IPREC, OPREC), FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) + +#define LIBXSMM_REGDESC_DEFAULT +#define LIBXSMM_REGDESC(START, MODIFIER) \ + START libxsmm_gemm_descriptor MODIFIER gemm; \ + START libxsmm_meltw_descriptor MODIFIER meltw; \ + START libxsmm_meqn_descriptor MODIFIER meqn + +/** +* Packed structure, which stores the argument description of GEMM routines. +* The size of the structure is padded to LIBXSMM_DESCRIPTOR_MAXSIZE. +*/ +LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_gemm_descriptor { + /** Extents of the matrix. */ + unsigned int m, n, k; + /** Leading dimensions. */ + unsigned int lda, ldb, ldc; + /** Set of flags. */ + unsigned int flags; + /** Prefetch strategy. */ + unsigned char prefetch; + /** Denotes the data-type. */ + unsigned char datatype; + /** + * Do not reorder elements between above and below blocks! + */ + /** Denotes of optional eltwise data-type */ + unsigned char meltw_datatype_aux; + /** multipurpose 64-bit field, currently used for: a) stride_a in brgemm */ + unsigned long long c1; + /** multipurpose 64-bit field, currently used for: a) stride_b in brgemm */ + unsigned long long c2; + /** multipurpose 8-bit field, currently used for: a) unroll hint in brgemm */ + unsigned char c3; + /** LDx, LDy, LDz, additional meltw LDs */ + unsigned int meltw_ldx, meltw_ldy, meltw_ldz; + /** optional param field */ + unsigned char meltw_param; + /** Set of flags */ + unsigned short meltw_flags; + /** operation specifier */ + unsigned char meltw_operation; +}; + +/** Packed structure storing the mateltw argument description. */ +LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_meltw_descriptor { + /** LDx, M, and N. */ + unsigned int m, n, ldi, ldo, ldi2, ldi3; + /** Size of data element. */ + unsigned char datatype; + unsigned char datatype2; + /** Set of flags */ + unsigned short flags; + /** optional param field */ + unsigned char param; + /** operation specifier */ + unsigned char operation; +}; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_pspgemm_csr_descriptor { + const libxsmm_gemm_descriptor* gemm; + const unsigned int* row_ptr; + const unsigned int* column_idx; + const void* values; + unsigned int packed_width; +} libxsmm_pspgemm_csr_descriptor; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_pspgemm_csc_descriptor { + const libxsmm_gemm_descriptor* gemm; + const unsigned int* column_ptr; + const unsigned int* row_idx; + const void* values; + unsigned int packed_width; +} libxsmm_pspgemm_csc_descriptor; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_pgemm_ac_rm_descriptor { + const libxsmm_gemm_descriptor* gemm; + unsigned int packed_width; +} libxsmm_pgemm_ac_rm_descriptor; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_pgemm_bc_rm_descriptor { + const libxsmm_gemm_descriptor* gemm; + unsigned int packed_width; +} libxsmm_pgemm_bc_rm_descriptor; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_csr_reg_descriptor { + const libxsmm_gemm_descriptor* gemm; + const unsigned int* row_ptr; + const unsigned int* column_idx; + const void* values; +} libxsmm_csr_reg_descriptor; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xcopykernel { + libxsmm_meltwfunction_unary function; + const void* ptr; +} libxsmm_xcopykernel; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_code_pointer { + void (*ptr_fn)(LIBXSMM_VARIADIC); + const void* ptr_const; + void* ptr; + uintptr_t uval; + intptr_t ival; + libxsmm_xmmfunction xgemm; /* GEMM: smm, dmm, wimm, or void-function */ + libxsmm_xmeltwfunction xmateltw; + libxsmm_matrix_eqn_function xmateqn; +} libxsmm_code_pointer; + +/** Structure which describes all tensors in LIBXSMM's DNN module */ +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor { + libxsmm_dnn_tensor_datalayout* layout; /* data-layout descriptor */ + void* data; /* pointer to data */ + unsigned char scf; /* fix point scaling factor for this tensor */ +}; + +/* Structure to record segment in stream of code */ +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE segment_t { + int segment_type; + int n_convs; + int aux_index; + int img; + int ofm; + int ifm; +} segment_t; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_layer { + libxsmm_dnn_datatype datatype_in; + libxsmm_dnn_datatype datatype_out; + libxsmm_dnn_conv_desc desc; + libxsmm_dnn_conv_algo algo; + libxsmm_dnn_tensor_format buffer_format; + libxsmm_dnn_tensor_format filter_format; + libxsmm_dnn_conv_fuse_op fuse_ops; + libxsmm_dnn_conv_option options; + int target_archid; + + /* additional size for internal data types */ + int ifhp; + int ifwp; + int ofh; + int ofw; + int ofhp; + int ofwp; + int ifmblock; + int ofmblock; + int blocksifm; + int blocksofm; + int fwd_ofw_rb; + int fwd_ofh_rb; + int bwd_ofw_rb; + int bwd_ofh_rb; + int upd_ofw_rb; + int upd_ofh_rb; + int fm_lp_block; /* additional blocking for low precision datatypes of feature maps */ + int blocksifm_blocking; + int blocksofm_blocking; + int avoid_acc_load; + int avoid_acc_load_bwd; + int pack_input; + int pack_input_bwd; + int spread_input_bwd; + int weight_copies; + int loop_order; + int use_ofm_parallelization; + int use_ifm_parallelization; + int avoid_fmas_in_rim; + int upd_use_batchreduce; + int upd_pack_input; + int upd_loop_order; + int upd_linearized_tasklist; + int upd_avoid_rim_fmas; + int fwd_flags; + int bwd_flags; + int shuffle_filter_accesses; + int use_fallback_fwd_loops; + int use_fallback_bwd_loops; + int fwd_gemm_pixels; + int bwd_gemm_pixels; + int input_pixels; + int output_pixels; + int n_used_pixels; + int pixel_blocking; + int use_intermediate_f32_wt_tensor; + int upd_linearized_pixels; + int ifwp_extended; + int ofwp_extended; + int batchreduce_h_pixels; + int on_the_fly_input_packing; + int upd_pack_input_upfront; + int use_hybrid_imgofm_parallelization; + int remainder_pixels; + int pack_to_cnhw; + int fuse_upd_transposes; + int compute_pixels; + int upd_trans_w_only; + int fwd_padding_copy; + int upd_padding_copy; + int block_fwd_oj; + int block_fwd_ifm; + int block_fwd_ofm; + int block_bwd_oj; + int block_bwd_ifm; + int block_bwd_ofm; + int block_upd_ifm; + int block_upd_ofm; + + libxsmm_meltwfunction_unary tr_kernel; + libxsmm_meltwfunction_unary fwd_cvtfp32bf16_kernel; + + /* Hoisting the compute kernels for FWD */ + libxsmm_bsmmfunction fwd_config_kernel; + libxsmm_bsmmfunction_reducebatch_addr fwd_compute_kernel_addr; + libxsmm_bsmmfunction_reducebatch_offs fwd_compute_kernel_offs_b; + libxsmm_bmmfunction_reducebatch_offs fwd_compute_kernel_offs_a; + libxsmm_bmmfunction_reducebatch_strd fwd_compute_kernel_strd; + libxsmm_smmfunction_reducebatch_addr fwd_compute_kernel_addr_a_f32; + libxsmm_smmfunction_reducebatch_addr fwd_compute_kernel_addr_b_f32; + libxsmm_smmfunction_reducebatch_offs fwd_compute_kernel_offs_f32; + libxsmm_smmfunction_reducebatch_strd fwd_compute_kernel_strd_f32; + + /* Hoisting the compute kernels for BWD */ + libxsmm_bsmmfunction bwd_config_kernel; + libxsmm_bsmmfunction_reducebatch_addr bwd_compute_kernel_addr; + libxsmm_bsmmfunction_reducebatch_offs bwd_compute_kernel_offs; + libxsmm_bsmmfunction_reducebatch_strd bwd_compute_kernel_strd; + + /* Hoisting the compute kernels for UPD */ + libxsmm_bsmmfunction upd_config_kernel; + libxsmm_bsmmfunction_reducebatch_strd upd_compute_kernel_brgemm_no_linearized_pixels; + libxsmm_bsmmfunction_reducebatch_strd upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw; + libxsmm_bsmmfunction upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw; + libxsmm_bsmmfunction upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par; + + libxsmm_bsmmfunction tilerelease_kernel; + + unsigned long long *A_offsets; + unsigned long long *B_offsets; + unsigned long long *A_offsets_bwd; + unsigned long long *B_offsets_bwd; + + /* AMX specific fields */ + int x_rows; + int n_pixel_tiles; + int n_ofm_tiles; + int wrb_1; + int wrb_2; + int wrb_3; + int wrb_4; + int hrb_1; + int hrb_2; + int n_compute_pixels; + int pixels; + int linearize_pixels; + int split_pixel; + int reconfig; + int zero_rim; + char tc[64]; + char tc2[64]; + char tc_upd[64]; + int input_padded_pixels; + int output_padded_pixels; + int blocks_pixels; + /* End of AMX specific fields */ + + /* internal data representation */ + libxsmm_dnn_tensor* reg_input; + libxsmm_dnn_tensor* reg_output; + libxsmm_dnn_tensor* reg_filter; + libxsmm_dnn_tensor* grad_input; + libxsmm_dnn_tensor* grad_output; + libxsmm_dnn_tensor* grad_filter; + libxsmm_dnn_tensor* reg_bias; + libxsmm_dnn_tensor* grad_bias; + /* internal data representations for copies of tensors */ + libxsmm_dnn_tensor* reg_input_tr; + libxsmm_dnn_tensor* reg_filter_tr; + /* batchnorm stats */ + libxsmm_dnn_tensor* batch_stats; + /* maxstats used in low-precision kernels */ + libxsmm_dnn_tensor* maxstats_fwd; + libxsmm_dnn_tensor* maxstats_bwd; + libxsmm_dnn_tensor* maxstats_upd; + + /* barrier */ + libxsmm_barrier* barrier; + + /* scratch */ + size_t fwd_packing_padding_scratch_size; + size_t fwd_lp_output_full_scratch_size; + size_t fwd_lp_output_block_scratch_size; + size_t fwd_packing_padding_scratch_offset; + size_t fwd_lp_output_full_scratch_offset; + size_t fwd_lp_output_block_scratch_offset; + size_t fwd_scratch_size; + + size_t bwd_filter_trans_scratch_size; + size_t bwd_packing_padding_scratch_size; + size_t bwd_lp_input_full_scratch_size; + size_t bwd_filter_trans_scratch_offset; + size_t bwd_packing_padding_scratch_offset; + size_t bwd_lp_input_full_scratch_offset; + size_t bwd_scratch_size; + + size_t upd_packing_padding_scratch_size; + size_t upd_lp_output_full_scratch_size; + size_t upd_lp_input_full_scratch_size; + size_t upd_filter_scratch_size; + size_t upd_lp_filter_full_scratch_size; + size_t upd_packing_padding_scratch_offset; + size_t upd_lp_output_full_scratch_offset; + size_t upd_lp_input_full_scratch_offset; + size_t upd_lp_filter_full_scratch_offset; + size_t upd_filter_scratch_offset; + size_t upd_scratch_size; + + void* scratch; + size_t scratch_size; + + libxsmm_code_pointer gemm_fwd; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd2; /* ability to hoist forward GEMMs */ + + /* JIT-generated convolution code */ + libxsmm_code_pointer code_fwd[3]; + libxsmm_code_pointer code_bwd[3]; + libxsmm_code_pointer code_upd[5]; + + libxsmm_code_pointer matcopy_fwd[4]; + libxsmm_code_pointer matcopy_bwd[4]; + libxsmm_code_pointer matcopy_upd[3]; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedbatchnorm { + libxsmm_dnn_fusedbatchnorm_desc desc; + libxsmm_dnn_tensor* reg_input; /* input tensor */ + libxsmm_dnn_tensor* reg_output; /* output tensor */ + libxsmm_dnn_tensor* grad_input; /* grad input tensor */ + libxsmm_dnn_tensor* grad_output; /* grad output tensor */ + libxsmm_dnn_tensor* reg_add; /* elementwise tensor */ + libxsmm_dnn_tensor* grad_add; /* grad elementwise tensor */ + libxsmm_dnn_tensor* reg_beta; /* beta tensor */ + libxsmm_dnn_tensor* reg_gamma; /* gamma tensor */ + libxsmm_dnn_tensor* grad_beta; /* grad beta tensor */ + libxsmm_dnn_tensor* grad_gamma; /* grad gamma tensor */ + libxsmm_dnn_tensor* expvalue; /* expected value */ + libxsmm_dnn_tensor* rcpstddev; /* reciprocal of standard derivation */ + libxsmm_dnn_tensor* variance; /* variance */ + libxsmm_dnn_tensor* relumask; /* relumask */ + libxsmm_barrier* barrier; /* barrier */ + int ifmblock; + int ofmblock; + int blocksifm; + int blocksofm; + size_t scratch_size; + void* scratch; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_softmaxloss { + libxsmm_dnn_softmaxloss_desc desc; + libxsmm_dnn_tensor* reg_input; /* input tensor */ + libxsmm_dnn_tensor* reg_output; /* output tensor */ + libxsmm_dnn_tensor* grad_input; /* grad input tensor */ + libxsmm_dnn_tensor* label; /* labels tensor */ + libxsmm_barrier* barrier; /* barrier */ + int bc; + int Bc; + int bn; + int Bn; + float loss; + size_t scratch_size; + void* scratch; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_optimizer { + libxsmm_dnn_optimizer_desc desc; + libxsmm_dnn_tensor* reg_filter; /* filter tensor */ + libxsmm_dnn_tensor* grad_filter; /* grad filter tensor */ + libxsmm_dnn_tensor* master_filter; /* master filter tensor */ + libxsmm_barrier* barrier; /* barrier */ + int bc; + int Bc; + int bk; + int Bk; + int fm_lp_block; + size_t scratch_size; + void* scratch; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedgroupnorm { + libxsmm_dnn_fusedgroupnorm_desc desc; + libxsmm_dnn_tensor* reg_input; /* input tensor */ + libxsmm_dnn_tensor* reg_output; /* output tensor */ + libxsmm_dnn_tensor* grad_input; /* grad input tensor */ + libxsmm_dnn_tensor* grad_output; /* grad output tensor */ + libxsmm_dnn_tensor* reg_add; /* elementwise tensor */ + libxsmm_dnn_tensor* grad_add; /* grad elementwise tensor */ + libxsmm_dnn_tensor* reg_beta; /* beta tensor */ + libxsmm_dnn_tensor* reg_gamma; /* gamma tensor */ + libxsmm_dnn_tensor* grad_beta; /* grad beta tensor */ + libxsmm_dnn_tensor* grad_gamma; /* grad gamma tensor */ + libxsmm_dnn_tensor* expvalue; /* expected value */ + libxsmm_dnn_tensor* rcpstddev; /* reciprocal of standard derivation */ + libxsmm_dnn_tensor* variance; /* variance */ + libxsmm_dnn_tensor* relumask; /* relumask */ + libxsmm_barrier* barrier; /* barrier */ + int ifmblock; + int ofmblock; + int blocksifm; + int blocksofm; + size_t scratch_size; + void* scratch; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_fullyconnected { + libxsmm_dnn_fullyconnected_desc desc; + libxsmm_dnn_tensor* reg_input; /* input tensor */ + libxsmm_dnn_tensor* reg_output; /* output tensor */ + libxsmm_dnn_tensor* grad_input; /* grad input tensor */ + libxsmm_dnn_tensor* grad_output; /* grad output tensor */ + libxsmm_dnn_tensor* reg_filter; /* filter tensor */ + libxsmm_dnn_tensor* grad_filter; /* grad filter tensor */ + libxsmm_dnn_tensor* reg_bias; /* bias tensor */ + libxsmm_dnn_tensor* grad_bias; /* grad bais tensor */ + libxsmm_dnn_tensor* relumask; /* relumask */ + libxsmm_barrier* barrier; /* barrier */ + int target_archid; + + int ifmblock; + int ofmblock; + int blocksifm; + int blocksofm; + /* Parameters to tune/specialize FC algorithms */ + int fwd_2d_blocking; + int bwd_2d_blocking; + int upd_2d_blocking; + int fwd_bf; + int bwd_bf; + int upd_bf; + int fwd_row_teams; + int fwd_column_teams; + int bwd_row_teams; + int bwd_column_teams; + int upd_row_teams; + int upd_column_teams; + int ifm_subtasks; + int ofm_subtasks; + int compressed_A; + int sparsity_factor_A; + + int fm_lp_block; + int bn; + int bk; + int bc; + size_t scratch_size; + size_t doutput_scratch_mark; + void* scratch; + + libxsmm_bsmmfunction fwd_config_kernel; + libxsmm_bsmmfunction bwd_config_kernel; + libxsmm_bsmmfunction upd_config_kernel; + libxsmm_bsmmfunction tilerelease_kernel; + + libxsmm_meltwfunction_unary tr_kernel; + libxsmm_code_pointer gemm_fwd; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd2; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd3; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd4; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd5; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd6; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd7; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd8; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd9; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd10; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd11; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd12; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd13; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd14; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd15; /* ability to hoist forward GEMMs */ + libxsmm_code_pointer gemm_fwd16; /* ability to hoist forward GEMMs */ + + libxsmm_code_pointer gemm_bwd; /* ability to hoist backward GEMMs */ + libxsmm_code_pointer gemm_bwd2; /* ability to hoist backward GEMMs */ + libxsmm_code_pointer gemm_bwd3; /* ability to hoist backward GEMMs */ + libxsmm_code_pointer gemm_upd; /* ability to hoist update GEMMs */ + libxsmm_code_pointer gemm_upd2; /* ability to hoist update GEMMs */ + libxsmm_code_pointer gemm_upd3; /* ability to hoist update GEMMs */ + + /* JITed eltwise kernels... */ + libxsmm_meltwfunction_unary fwd_cvtfp32bf16_kernel; + libxsmm_meltwfunction_unary bwd_cvtfp32bf16_kernel; + libxsmm_meltwfunction_unary bwd_relu_kernel; + libxsmm_meltwfunction_unary fwd_cvtfp32bf16_relu_kernel; + libxsmm_meltwfunction_unary fwd_sigmoid_cvtfp32bf16_kernel; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_pooling { + libxsmm_dnn_pooling_desc desc; + libxsmm_dnn_tensor* reg_input; /* input tensor */ + libxsmm_dnn_tensor* reg_output; /* output tensor */ + libxsmm_dnn_tensor* grad_input; /* grad input tensor */ + libxsmm_dnn_tensor* grad_output; /* grad output tensor */ + libxsmm_dnn_tensor* mask; /* elementwise tensor */ + libxsmm_barrier* barrier; /* barrier */ + int ifmblock; + int ofmblock; + int blocksifm; + int blocksofm; + int ofh; + int ofw; + size_t scratch_size; + void* scratch; +}; + +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell { + libxsmm_dnn_rnncell_desc desc; + libxsmm_blasint T; /* sequence length, must be smaller than max sequence length in desc */ + libxsmm_blasint bk; + libxsmm_blasint bn; + libxsmm_blasint bc; + libxsmm_blasint lpb; + + /* external tensors */ + libxsmm_dnn_tensor* xt; + libxsmm_dnn_tensor* csp; + libxsmm_dnn_tensor* hp; + libxsmm_dnn_tensor* w; + libxsmm_dnn_tensor* wt; + libxsmm_dnn_tensor* r; + libxsmm_dnn_tensor* rt; + libxsmm_dnn_tensor* b; + libxsmm_dnn_tensor* cst; + libxsmm_dnn_tensor* ht; + libxsmm_dnn_tensor* dxt; + libxsmm_dnn_tensor* dcsp; + libxsmm_dnn_tensor* dhp; + libxsmm_dnn_tensor* dw; + libxsmm_dnn_tensor* dr; + libxsmm_dnn_tensor* db; + libxsmm_dnn_tensor* dcs; + libxsmm_dnn_tensor* dht; + libxsmm_dnn_tensor* it; + libxsmm_dnn_tensor* ft; + libxsmm_dnn_tensor* ot; + libxsmm_dnn_tensor* cit; + libxsmm_dnn_tensor* cot; + float forget_bias; + /* internal state */ + void* internal_z; + /* scratch pointers */ + void* scratch_base; + void* scratch_wT; + void* scratch_rT; + void* scratch_w; + void* scratch_r; + void* scratch_xT; + void* scratch_hT; + void* scratch_deltat; + void* scratch_di; + void* scratch_df; + void* scratch_do; + void* scratch_dci; + void* scratch_diB; + void* scratch_dfB; + void* scratch_dpB; + void* scratch_dciB; + void* scratch_dx; + void* scratch_dhp; + void* scratch_db; + void* scratch_t1; + void* scratch_t2; + void* csp_scratch; + void* cst_scratch; + void* ht_scratch; + void* it_scratch; + void* ft_scratch; + void* ot_scratch; + void* cit_scratch; + void* cot_scratch; + /* options */ + int use_fwd_fused_impl; + int fwd_block; + int bwdupd_block; + int fwd_generic; + int bwdupd_generic; + /* Ability to hoist GEMMs */ + libxsmm_bsmmfunction_reducebatch_strd fwd_kernela; + libxsmm_bsmmfunction_reducebatch_strd fwd_kernelb; + libxsmm_bsmmfunction_reducebatch_addr fwd_tileconfig; + libxsmm_bsmmfunction_reducebatch_strd bwdupd_kernela; + libxsmm_bsmmfunction_reducebatch_strd bwdupd_kernelb; + libxsmm_bsmmfunction_reducebatch_strd bwdupd_kernelc; + libxsmm_bsmmfunction_reducebatch_strd bwdupd_kerneld; + libxsmm_bsmmfunction_reducebatch_addr bwdupd_tileconfig; + libxsmm_bsmmfunction tilerelease_kernel; + libxsmm_barrier* barrier; /* barrier */ +}; + +struct LIBXSMM_RETARGETABLE libxsmm_dfsspmdm { + int M; + int N; + int K; + int ldb; + int ldc; + int N_chunksize; + double* a_dense; + libxsmm_dmmfunction kernel; +}; + +struct LIBXSMM_RETARGETABLE libxsmm_sfsspmdm { + int M; + int N; + int K; + int ldb; + int ldc; + int N_chunksize; + float* a_dense; + libxsmm_smmfunction kernel; +}; + +/** Packed structure storing the mateltw argument description. */ +LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_meqn_descriptor { + /** LDx, M, and N. */ + unsigned int m, n, ldo; + /** Size of data element. */ + unsigned char datatype; + /** Set of flags */ + unsigned int eqn_idx; +}; + +typedef enum libxsmm_build_kind { + LIBXSMM_BUILD_KIND_GEMM = LIBXSMM_KERNEL_KIND_MATMUL, + LIBXSMM_BUILD_KIND_MELTW = LIBXSMM_KERNEL_KIND_MELTW, + LIBXSMM_BUILD_KIND_MEQN = LIBXSMM_KERNEL_KIND_MEQN, + LIBXSMM_BUILD_KIND_USER = LIBXSMM_KERNEL_KIND_USER, + LIBXSMM_BUILD_KIND_PGEMMRMAC = LIBXSMM_KERNEL_UNREGISTERED, + LIBXSMM_BUILD_KIND_PGEMMRMBC, + LIBXSMM_BUILD_KIND_PSPGEMM_CSR, + LIBXSMM_BUILD_KIND_PSPGEMM_CSC, + LIBXSMM_BUILD_KIND_SREG +} libxsmm_build_kind; + +/** Integral type (libxsmm_kernel_kind, libxsmm_build_kind). */ +#if defined(LIBXSMM_UNPACKED) +# define LIBXSMM_DESCRIPTOR_BIG(KIND) ((libxsmm_descriptor_kind)((KIND) | 0x8000000000000000)) +# define LIBXSMM_DESCRIPTOR_ISBIG(KIND) ((int)(((libxsmm_descriptor_kind)(KIND)) >> 63)) +# define LIBXSMM_DESCRIPTOR_KIND(KIND) ((int)(((libxsmm_descriptor_kind)(KIND)) & 0x7FFFFFFFFFFFFFFF)) +typedef uint64_t libxsmm_descriptor_kind; +#else +# define LIBXSMM_DESCRIPTOR_BIG(KIND) ((libxsmm_descriptor_kind)((KIND) | 0x80)) +# define LIBXSMM_DESCRIPTOR_ISBIG(KIND) ((int)((KIND) >> 7)) +# define LIBXSMM_DESCRIPTOR_KIND(KIND) ((int)((KIND) & 0x7F)) +typedef unsigned char libxsmm_descriptor_kind; +#endif + +/** All descriptor types, which are valid for code-registration. */ +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_descriptor { + char data[LIBXSMM_DESCRIPTOR_MAXSIZE]; + libxsmm_descriptor_kind kind; /* kind: must be the first member */ + LIBXSMM_REGDESC(LIBXSMM_PACKED(struct) { libxsmm_descriptor_kind /*repeated kind*/ pad; , desc; }); + LIBXSMM_PACKED(struct) { libxsmm_descriptor_kind /*repeated kind*/ pad; char desc[1]; } user; +} libxsmm_descriptor; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_build_request { + union { + const void* ptr; /* raw content */ + LIBXSMM_REGDESC(LIBXSMM_REGDESC_DEFAULT, const*); + const libxsmm_pspgemm_csr_descriptor* pspgemm_csr; + const libxsmm_pspgemm_csc_descriptor* pspgemm_csc; + const libxsmm_pgemm_ac_rm_descriptor* pgemmacrm; + const libxsmm_pgemm_bc_rm_descriptor* pgemmbcrm; + const libxsmm_csr_reg_descriptor* sreg; + } descriptor; + libxsmm_build_kind kind; + /* used by user-kind */ + size_t user_size; +} libxsmm_build_request; + +typedef enum libxsmm_malloc_flags { + LIBXSMM_MALLOC_FLAG_DEFAULT = 0, + LIBXSMM_MALLOC_FLAG_SCRATCH = 1, + LIBXSMM_MALLOC_FLAG_PRIVATE = 2, + LIBXSMM_MALLOC_FLAG_REALLOC = 4, + LIBXSMM_MALLOC_FLAG_PHUGE = 8, + LIBXSMM_MALLOC_FLAG_PLOCK = 16, + LIBXSMM_MALLOC_FLAG_MMAP = 32, + LIBXSMM_MALLOC_FLAG_R = 64, + LIBXSMM_MALLOC_FLAG_W = 128, + LIBXSMM_MALLOC_FLAG_X = 256, + LIBXSMM_MALLOC_FLAG_RW = LIBXSMM_MALLOC_FLAG_R | LIBXSMM_MALLOC_FLAG_W, + LIBXSMM_MALLOC_FLAG_WX = LIBXSMM_MALLOC_FLAG_X | LIBXSMM_MALLOC_FLAG_W, + LIBXSMM_MALLOC_FLAG_RWX = LIBXSMM_MALLOC_FLAG_X | LIBXSMM_MALLOC_FLAG_RW, + LIBXSMM_MALLOC_FLAG_VALID = LIBXSMM_MALLOC_FLAG_SCRATCH | + LIBXSMM_MALLOC_FLAG_PRIVATE | LIBXSMM_MALLOC_FLAG_REALLOC | + LIBXSMM_MALLOC_FLAG_PHUGE | LIBXSMM_MALLOC_FLAG_PLOCK | + LIBXSMM_MALLOC_FLAG_MMAP | LIBXSMM_MALLOC_FLAG_RWX +} libxsmm_malloc_flags; + +LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void* (*libxsmm_realloc_fun)(void* /*ptr*/, size_t /*size*/); + +#if defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_malloc_fntype { + union { const void* dlsym; void* (*ptr)(size_t, size_t); } alignmem; + union { const void* dlsym; void* (*ptr)(size_t, size_t); } memalign; + union { const void* dlsym; libxsmm_malloc_fun ptr; } malloc; +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + union { const void* dlsym; void* (*ptr)(size_t, size_t); } calloc; +# endif +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + union { const void* dlsym; libxsmm_realloc_fun ptr; } realloc; +# endif + union { const void* dlsym; libxsmm_free_fun ptr; } free; +} libxsmm_malloc_fntype; +LIBXSMM_APIVAR_PRIVATE(libxsmm_malloc_fntype libxsmm_malloc_fn); +#endif + +#if (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) +/* prototypes for GLIBC internal implementation */ +LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_memalign(size_t alignment, size_t size); +LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_malloc(size_t size); +#if defined(LIBXSMM_MALLOC_HOOK_CALLOC) +LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_calloc(size_t num, size_t size); +#endif +#if defined(LIBXSMM_MALLOC_HOOK_REALLOC) +LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_realloc(void* ptr, size_t size); +#endif +LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void __libc_free(void* ptr); +#endif /*(defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD)))*/ + +LIBXSMM_API_INTERN void* libxsmm_memalign_internal(size_t alignment, size_t size); + +/* See https://sourceware.org/binutils/docs-2.34/ld/Options.html#index-_002d_002dwrap_003dsymbol */ +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_memalign(size_t alignment, size_t size); +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_malloc(size_t size); +#if defined(LIBXSMM_MALLOC_HOOK_CALLOC) +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_calloc(size_t num, size_t size); +#endif +#if defined(LIBXSMM_MALLOC_HOOK_REALLOC) +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_realloc(void* ptr, size_t size); +#endif +LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void __real_free(void* ptr); + +/** Retrieve internal information about a buffer (default memory domain). */ +LIBXSMM_API int libxsmm_get_malloc_xinfo(const void* memory, size_t* size, int* flags, void** extra); + +/** Initializes malloc hooks and other internals. */ +LIBXSMM_API_INTERN void libxsmm_malloc_init(void); +LIBXSMM_API_INTERN void libxsmm_malloc_finalize(void); + +/** Calculates an alignment depending on supposedly allocated size; alignment can be zero ("auto"). */ +LIBXSMM_API_INTERN size_t libxsmm_alignment(size_t size, size_t alignment); + +/** Same as libxsmm_set_default_allocator, but takes a lock (can be NULL). */ +LIBXSMM_API_INTERN int libxsmm_xset_default_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn); +/** Same as libxsmm_get_default_allocator, but takes a lock (can be NULL). */ +LIBXSMM_API_INTERN int libxsmm_xget_default_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void** context, libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn); + +/** Same as libxsmm_set_scratch_allocator, but takes a lock (can be NULL). */ +LIBXSMM_API_INTERN int libxsmm_xset_scratch_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn); +/** Same as libxsmm_get_scratch_allocator, but takes a lock (can be NULL). */ +LIBXSMM_API_INTERN int libxsmm_xget_scratch_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void** context, libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn); + +/** + * Attribute memory allocation and protect with only the necessary flags. + * This procedure is expected to run only one time per buffer, and may + * relocate the given memory. + */ +LIBXSMM_API_INTERN int libxsmm_malloc_attrib(void** memory, int flags, + /** If a name is given, an executable buffer will be dumped into a file. */ + const char* name); + +/** Like libxsmm_release_scratch, but takes a lock (can be NULL). */ +LIBXSMM_API_INTERN void libxsmm_xrelease_scratch(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock); + +/** Allocate memory of the requested size, which is aligned according to the given alignment. */ +LIBXSMM_API int libxsmm_xmalloc(void** memory, size_t size, size_t alignment, int flags, + /* The extra information is stored along with the allocated chunk; can be NULL/zero. */ + const void* extra, size_t extra_size); +/** Release memory, which was allocated using libxsmm_[*]malloc. */ +LIBXSMM_API void libxsmm_xfree(const void* memory, int check); + +/** + * Format for instance an amount of Bytes like libxsmm_format_value(result, sizeof(result), nbytes, "KMGT", "B", 10). + * The value returned is in requested/determined unit so that the user can decide about printing the buffer. + */ +LIBXSMM_API_INTERN size_t libxsmm_format_value(char buffer[32], int buffer_size, size_t nbytes, const char scale[], const char* unit, int base); + +/** Returns the type-name of data-type (can be also libxsmm_gemm_precision). */ +LIBXSMM_API_INTERN const char* libxsmm_typename(libxsmm_datatype datatype); + +/** Dump data and (optionally) checks attempt to dump different data into an existing file (unique). */ +LIBXSMM_API_INTERN int libxsmm_dump(const char* title, const char* name, const void* data, size_t size, int unique); + +/** Services a build request, and (optionally) registers the code (use regindex=LIBXSMM_CAPACITY_REGISTRY for unmanaged code). */ +LIBXSMM_API_INTERN int libxsmm_build(const libxsmm_build_request* request, unsigned int regindex, libxsmm_code_pointer* code); + +/** Returns the type-size of data-type (can be also libxsmm_gemm_precision). */ +LIBXSMM_API unsigned char libxsmm_typesize(libxsmm_datatype datatype); + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_kernel_xinfo { + /** Non-zero if kernel is registered. */ + unsigned int registered; + /** Number of FLoating Point OPerationS (FLOPS). */ + unsigned int nflops; +} libxsmm_kernel_xinfo; + +/** Receive information about JIT-generated code. */ +LIBXSMM_API_INTERN const libxsmm_kernel_xinfo* libxsmm_get_kernel_xinfo(libxsmm_code_pointer code, const libxsmm_descriptor** desc, size_t* code_size); + +/** Calculates duration in seconds from given RTC ticks. */ +LIBXSMM_API_INTERN double libxsmm_timer_duration_rtc(libxsmm_timer_tickint tick0, libxsmm_timer_tickint tick1); +/** Returns the current tick of platform-specific real-time clock. */ +LIBXSMM_API_INTERN libxsmm_timer_tickint libxsmm_timer_tick_rtc(void); +/** Returns the current tick of a (monotonic) platform-specific counter. */ +LIBXSMM_API_INTERN libxsmm_timer_tickint libxsmm_timer_tick_tsc(void); + +LIBXSMM_API_INTERN void libxsmm_memory_init(int target_arch); +LIBXSMM_API_INTERN void libxsmm_memory_finalize(void); + +LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch); +LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void); + +/** intern function to calculate blockings, that's private API hence it's in this function */ +LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks( + int C, int K, int* C_block, int* K_block, int* fm_lp_block, + libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out); + +/** Global lock; create an own lock for an independent domain. */ +LIBXSMM_APIVAR_PUBLIC(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK) libxsmm_lock_global); +/** Determines whether a threaded implementation is synchronized or not. */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_nosync); + +/** Function used to allocate default memory. */ +LIBXSMM_APIVAR_PRIVATE(libxsmm_malloc_function libxsmm_default_malloc_fn); +/** Function used to allocate scratch memory. */ +LIBXSMM_APIVAR_PRIVATE(libxsmm_malloc_function libxsmm_scratch_malloc_fn); +/** Function used to release default memory. */ +LIBXSMM_APIVAR_PRIVATE(libxsmm_free_function libxsmm_default_free_fn); +/** Function used to release scratch memory. */ +LIBXSMM_APIVAR_PRIVATE(libxsmm_free_function libxsmm_scratch_free_fn); +/** If non-NULL, this context is used by the context-form of memory allocation. */ +LIBXSMM_APIVAR_PRIVATE(const void* libxsmm_default_allocator_context); +/** If non-NULL, this context is used by the context-form of memory allocation. */ +LIBXSMM_APIVAR_PRIVATE(const void* libxsmm_scratch_allocator_context); +/** Number of scratch memory pools used; clamped against internal maximum. */ +LIBXSMM_APIVAR_PRIVATE(unsigned int libxsmm_scratch_pools); +/** Growth factor used to scale the scratch memory in case of reallocation. */ +LIBXSMM_APIVAR_PRIVATE(double libxsmm_scratch_scale); +/** Number of seconds per RDTSC-cycle (zero or negative if RDTSC invalid). */ +LIBXSMM_APIVAR_PRIVATE(double libxsmm_timer_scale); +/** Counts the number of attempts to create an SPMDM-handle. */ +LIBXSMM_APIVAR_PRIVATE(unsigned int libxsmm_statistic_num_spmdm); +/** Counts the maximum number of thread that have been active. */ +LIBXSMM_APIVAR_PRIVATE(unsigned int libxsmm_thread_count); + +#if (0 != LIBXSMM_SYNC) +LIBXSMM_APIVAR_PRIVATE(LIBXSMM_TLS_TYPE libxsmm_tlskey); +#endif + +#endif /*LIBXSMM_MAIN_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_malloc.c b/third_party/libxsmm/src/libxsmm_malloc.c new file mode 100644 index 0000000000000000000000000000000000000000..6555d5cc9b1dd6e815972769cc3bafe8e3fdf477 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_malloc.c @@ -0,0 +1,2617 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_trace.h" +#include "libxsmm_main.h" +#include "libxsmm_hash.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) +# include +# include +#endif +#if !defined(LIBXSMM_MALLOC_GLIBC) +# if defined(__GLIBC__) +# define LIBXSMM_MALLOC_GLIBC __GLIBC__ +# else +# define LIBXSMM_MALLOC_GLIBC 6 +# endif +#endif +#if defined(_WIN32) +# include +# include +# include +#else +# include +# if defined(__linux__) +# include +# include +# endif +# if defined(MAP_POPULATE) +# include +# endif +# include +# include +# include +# include +# if defined(__MAP_ANONYMOUS) +# define LIBXSMM_MAP_ANONYMOUS __MAP_ANONYMOUS +# elif defined(MAP_ANONYMOUS) +# define LIBXSMM_MAP_ANONYMOUS MAP_ANONYMOUS +# elif defined(MAP_ANON) +# define LIBXSMM_MAP_ANONYMOUS MAP_ANON +# else +# define LIBXSMM_MAP_ANONYMOUS 0x20 +# endif +# if defined(MAP_SHARED) +# define LIBXSMM_MAP_SHARED MAP_SHARED +# else +# define LIBXSMM_MAP_SHARED 0 +# endif +LIBXSMM_EXTERN int ftruncate(int, off_t) LIBXSMM_THROW; +LIBXSMM_EXTERN int mkstemp(char*) LIBXSMM_NOTHROW; +#endif +#if !defined(LIBXSMM_MALLOC_FINAL) +# define LIBXSMM_MALLOC_FINAL 3 +#endif +#if defined(LIBXSMM_VTUNE) +# if (2 <= LIBXSMM_VTUNE) /* no header file required */ +# if !defined(LIBXSMM_VTUNE_JITVERSION) +# define LIBXSMM_VTUNE_JITVERSION LIBXSMM_VTUNE +# endif +# define LIBXSMM_VTUNE_JIT_DESC_TYPE iJIT_Method_Load_V2 +# define LIBXSMM_VTUNE_JIT_LOAD 21 +# define LIBXSMM_VTUNE_JIT_UNLOAD 14 +# define iJIT_SAMPLING_ON 0x0001 +LIBXSMM_EXTERN unsigned int iJIT_GetNewMethodID(void); +LIBXSMM_EXTERN /*iJIT_IsProfilingActiveFlags*/int iJIT_IsProfilingActive(void); +LIBXSMM_EXTERN int iJIT_NotifyEvent(/*iJIT_JVM_EVENT*/int event_type, void *EventSpecificData); +LIBXSMM_EXTERN_C typedef struct LineNumberInfo { + unsigned int Offset; + unsigned int LineNumber; +} LineNumberInfo; +LIBXSMM_EXTERN_C typedef struct iJIT_Method_Load_V2 { + unsigned int method_id; + char* method_name; + void* method_load_address; + unsigned int method_size; + unsigned int line_number_size; + LineNumberInfo* line_number_table; + char* class_file_name; + char* source_file_name; + char* module_name; +} iJIT_Method_Load_V2; +# else /* more safe due to header dependency */ +# include +# if !defined(LIBXSMM_VTUNE_JITVERSION) +# define LIBXSMM_VTUNE_JITVERSION 2 +# endif +# if (2 <= LIBXSMM_VTUNE_JITVERSION) +# define LIBXSMM_VTUNE_JIT_DESC_TYPE iJIT_Method_Load_V2 +# define LIBXSMM_VTUNE_JIT_LOAD iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 +# else +# define LIBXSMM_VTUNE_JIT_DESC_TYPE iJIT_Method_Load +# define LIBXSMM_VTUNE_JIT_LOAD iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED +# endif +# define LIBXSMM_VTUNE_JIT_UNLOAD iJVM_EVENT_TYPE_METHOD_UNLOAD_START +# endif +# if !defined(LIBXSMM_MALLOC_FALLBACK) +# define LIBXSMM_MALLOC_FALLBACK LIBXSMM_MALLOC_FINAL +# endif +#else /* VTune JIT-API not enabled */ +# if !defined(LIBXSMM_MALLOC_FALLBACK) +# define LIBXSMM_MALLOC_FALLBACK 0 +# endif +#endif /*defined(LIBXSMM_VTUNE)*/ +#if !defined(LIBXSMM_MALLOC_XMAP_TEMPLATE) +# define LIBXSMM_MALLOC_XMAP_TEMPLATE ".libxsmm_jit." LIBXSMM_MKTEMP_PATTERN +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif +#if defined(LIBXSMM_PERF) +# include "libxsmm_perf.h" +#endif + +#if !defined(LIBXSMM_MALLOC_ALIGNMAX) +# define LIBXSMM_MALLOC_ALIGNMAX (2 << 20) /* 2 MB */ +#endif +#if !defined(LIBXSMM_MALLOC_ALIGNFCT) +# define LIBXSMM_MALLOC_ALIGNFCT 16 +#endif +#if !defined(LIBXSMM_MALLOC_SEED) +# define LIBXSMM_MALLOC_SEED 1051981 +#endif + +#if !defined(LIBXSMM_MALLOC_HOOK_KMP) && 0 +# define LIBXSMM_MALLOC_HOOK_KMP +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_QKMALLOC) && 0 +# define LIBXSMM_MALLOC_HOOK_QKMALLOC +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_IMALLOC) && 1 +# define LIBXSMM_MALLOC_HOOK_IMALLOC +#endif +#if !defined(LIBXSMM_MALLOC_HOOK_CHECK) && 0 +# define LIBXSMM_MALLOC_HOOK_CHECK 1 +#endif + +#if !defined(LIBXSMM_MALLOC_CRC_LIGHT) && !defined(_DEBUG) && 1 +# define LIBXSMM_MALLOC_CRC_LIGHT +#endif +#if !defined(LIBXSMM_MALLOC_CRC_OFF) +# if defined(NDEBUG) && !defined(LIBXSMM_MALLOC_HOOK) +# define LIBXSMM_MALLOC_CRC_OFF +# elif !defined(LIBXSMM_BUILD) +# define LIBXSMM_MALLOC_CRC_OFF +# endif +#endif + +#if !defined(LIBXSMM_MALLOC_SCRATCH_LIMIT) +# define LIBXSMM_MALLOC_SCRATCH_LIMIT 0xFFFFFFFF /* ~4 GB */ +#endif +#if !defined(LIBXSMM_MALLOC_SCRATCH_PADDING) +# define LIBXSMM_MALLOC_SCRATCH_PADDING LIBXSMM_CACHELINE +#endif +/* pointers are checked first if they belong to scratch */ +#if !defined(LIBXSMM_MALLOC_SCRATCH_DELETE_FIRST) && 1 +# define LIBXSMM_MALLOC_SCRATCH_DELETE_FIRST +#endif +/* can clobber memory if allocations are not exactly scoped */ +#if !defined(LIBXSMM_MALLOC_SCRATCH_TRIM_HEAD) && 0 +# define LIBXSMM_MALLOC_SCRATCH_TRIM_HEAD +#endif +#if !defined(LIBXSMM_MALLOC_SCRATCH_JOIN) && 1 +# define LIBXSMM_MALLOC_SCRATCH_JOIN +#endif +#if !defined(LIBXSMM_MALLOC_HUGE_PAGES) && 1 +# define LIBXSMM_MALLOC_HUGE_PAGES +#endif +#if !defined(LIBXSMM_MALLOC_LOCK_PAGES) && 1 +/* 0: on-map, 1: mlock, 2: mlock2/on-fault */ +# define LIBXSMM_MALLOC_LOCK_PAGES 1 +#endif +#if !defined(LIBXSMM_MALLOC_LOCK_ALL) && \ + defined(LIBXSMM_MALLOC_ALIGN_ALL) && 0 +# define LIBXSMM_MALLOC_LOCK_ALL +#endif +/* record real allocation size */ +#if !defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) && 0 +# define LIBXSMM_MALLOC_INFO_ALLOCSIZE +#endif +/* protected against double-delete (if possible) */ +#if !defined(LIBXSMM_MALLOC_DELETE_SAFE) && 0 +# define LIBXSMM_MALLOC_DELETE_SAFE +#elif !defined(NDEBUG) +# define LIBXSMM_MALLOC_DELETE_SAFE +#endif + +#define INTERNAL_MEMALIGN_REAL(RESULT, ALIGNMENT, SIZE) do { \ + const size_t internal_memalign_real_alignment_ = INTERNAL_MALLOC_AUTOALIGN(SIZE, ALIGNMENT); \ + (RESULT) = (0 != internal_memalign_real_alignment_ \ + ? __real_memalign(internal_memalign_real_alignment_, SIZE) \ + : __real_malloc(SIZE)); \ +} while(0) +#define INTERNAL_REALLOC_REAL(RESULT, PTR, SIZE) (RESULT) = __real_realloc(PTR, SIZE) +#define INTERNAL_FREE_REAL(PTR) __real_free(PTR) + +#if defined(LIBXSMM_MALLOC_LOCK_ALL) && defined(LIBXSMM_MALLOC_LOCK_PAGES) && 0 != (LIBXSMM_MALLOC_LOCK_PAGES) +# if 1 == (LIBXSMM_MALLOC_LOCK_PAGES) || !defined(MLOCK_ONFAULT) || !defined(SYS_mlock2) +# define INTERNAL_MALLOC_LOCK_PAGES(BUFFER, SIZE) if ((LIBXSMM_MALLOC_ALIGNFCT * LIBXSMM_MALLOC_ALIGNMAX) <= (SIZE)) \ + mlock(BUFFER, SIZE) +# else +# define INTERNAL_MALLOC_LOCK_PAGES(BUFFER, SIZE) if ((LIBXSMM_MALLOC_ALIGNFCT * LIBXSMM_MALLOC_ALIGNMAX) <= (SIZE)) \ + syscall(SYS_mlock2, BUFFER, SIZE, MLOCK_ONFAULT) +# endif +#else +# define INTERNAL_MALLOC_LOCK_PAGES(BUFFER, SIZE) +#endif + +#if defined(LIBXSMM_MALLOC_ALIGN_ALL) +# define INTERNAL_MALLOC_AUTOALIGN(SIZE, ALIGNMENT) libxsmm_alignment(SIZE, ALIGNMENT) +#else +# define INTERNAL_MALLOC_AUTOALIGN(SIZE, ALIGNMENT) (ALIGNMENT) +#endif + +#if defined(LIBXSMM_MALLOC_HOOK) && defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC) +# define INTERNAL_MEMALIGN_HOOK(RESULT, FLAGS, ALIGNMENT, SIZE, CALLER) { \ + const int internal_memalign_hook_recursive_ = LIBXSMM_ATOMIC_ADD_FETCH( \ + &internal_malloc_recursive, 1, LIBXSMM_ATOMIC_RELAXED); \ + if ( 1 < internal_memalign_hook_recursive_ /* protect against recursion */ \ + || 0 == (internal_malloc_kind & 1) || 0 >= internal_malloc_kind \ + || (internal_malloc_limit[0] > (SIZE)) \ + || (internal_malloc_limit[1] < (SIZE) && 0 != internal_malloc_limit[1])) \ + { \ + INTERNAL_MEMALIGN_REAL(RESULT, ALIGNMENT, SIZE); \ + } \ + else { /* redirect */ \ + LIBXSMM_INIT \ + if (NULL == (CALLER)) { /* libxsmm_trace_caller_id may allocate memory */ \ + internal_scratch_malloc(&(RESULT), SIZE, ALIGNMENT, FLAGS, \ + libxsmm_trace_caller_id(0/*level*/)); \ + } \ + else { \ + internal_scratch_malloc(&(RESULT), SIZE, ALIGNMENT, FLAGS, CALLER); \ + } \ + } \ + LIBXSMM_ATOMIC_SUB_FETCH(&internal_malloc_recursive, 1, LIBXSMM_ATOMIC_RELAXED); \ + } +# define INTERNAL_REALLOC_HOOK(RESULT, FLAGS, PTR, SIZE, CALLER) \ + if (0 == (internal_malloc_kind & 1) || 0 >= internal_malloc_kind \ + /*|| (0 != LIBXSMM_ATOMIC_LOAD(&internal_malloc_recursive, LIBXSMM_ATOMIC_RELAXED))*/ \ + || (internal_malloc_limit[0] > (SIZE)) \ + || (internal_malloc_limit[1] < (SIZE) && 0 != internal_malloc_limit[1])) \ + { \ + INTERNAL_REALLOC_REAL(RESULT, PTR, SIZE); \ + } \ + else { \ + const int nzeros = LIBXSMM_INTRINSICS_BITSCANFWD64((uintptr_t)(PTR)), alignment = 1 << nzeros; \ + LIBXSMM_ASSERT(0 == ((uintptr_t)(PTR) & ~(0xFFFFFFFFFFFFFFFF << nzeros))); \ + if (NULL == (CALLER)) { /* libxsmm_trace_caller_id may allocate memory */ \ + internal_scratch_malloc(&(PTR), SIZE, (size_t)alignment, FLAGS, \ + libxsmm_trace_caller_id(0/*level*/)); \ + } \ + else { \ + internal_scratch_malloc(&(PTR), SIZE, (size_t)alignment, FLAGS, CALLER); \ + } \ + (RESULT) = (PTR); \ + } +# define INTERNAL_FREE_HOOK(PTR, CALLER) { \ + LIBXSMM_UNUSED(CALLER); \ + if (0 == (internal_malloc_kind & 1) || 0 >= internal_malloc_kind \ + /*|| (0 != LIBXSMM_ATOMIC_LOAD(&internal_malloc_recursive, LIBXSMM_ATOMIC_RELAXED))*/ \ + ){ \ + INTERNAL_FREE_REAL(PTR); \ + } \ + else { /* recognize pointers not issued by LIBXSMM */ \ + libxsmm_free(PTR); \ + } \ + } +#elif defined(LIBXSMM_MALLOC_ALIGN_ALL) +# define INTERNAL_MEMALIGN_HOOK(RESULT, FLAGS, ALIGNMENT, SIZE, CALLER) do { \ + LIBXSMM_UNUSED(FLAGS); LIBXSMM_UNUSED(CALLER); \ + INTERNAL_MEMALIGN_REAL(RESULT, ALIGNMENT, SIZE); \ + INTERNAL_MALLOC_LOCK_PAGES(RESULT, SIZE); \ + } while(0) +# define INTERNAL_REALLOC_HOOK(RESULT, FLAGS, PTR, SIZE, CALLER) do { \ + LIBXSMM_UNUSED(FLAGS); LIBXSMM_UNUSED(CALLER); \ + INTERNAL_REALLOC_REAL(RESULT, PTR, SIZE); \ + INTERNAL_MALLOC_LOCK_PAGES(RESULT, SIZE); \ + } while(0) +# define INTERNAL_FREE_HOOK(PTR, CALLER) do { \ + LIBXSMM_UNUSED(CALLER); \ + INTERNAL_FREE_REAL(PTR); \ + } while(0) +#endif + +#if !defined(WIN32) +# if defined(MAP_32BIT) +# define INTERNAL_XMALLOC_MAP32(ENV, MAPSTATE, MFLAGS, SIZE, BUFFER, REPTR) \ + if (MAP_FAILED == (BUFFER) && 0 != (MAP_32BIT & (MFLAGS))) { \ + (BUFFER) = internal_xmalloc_xmap(ENV, SIZE, (MFLAGS) & ~MAP_32BIT, REPTR); \ + if (MAP_FAILED != (BUFFER)) (MAPSTATE) = 0; \ + } +# else +# define INTERNAL_XMALLOC_MAP32(ENV, MAPSTATE, MFLAGS, SIZE, BUFFER, REPTR) +# endif + +# define INTERNAL_XMALLOC(I, ENTRYPOINT, ENVVAR, ENVDEF, MAPSTATE, MFLAGS, SIZE, BUFFER, REPTR) \ + if ((ENTRYPOINT) <= (I) && (MAP_FAILED == (BUFFER) || NULL == (BUFFER))) { \ + static const char* internal_xmalloc_env_ = NULL; \ + LIBXSMM_ASSERT(NULL != (ENVVAR) && '\0' != *(ENVVAR)); \ + if (NULL == internal_xmalloc_env_) { \ + internal_xmalloc_env_ = getenv(ENVVAR); \ + if (NULL == internal_xmalloc_env_) internal_xmalloc_env_ = ENVDEF; \ + } \ + (BUFFER) = internal_xmalloc_xmap(internal_xmalloc_env_, SIZE, MFLAGS, REPTR); \ + INTERNAL_XMALLOC_MAP32(internal_xmalloc_env_, MAPSTATE, MFLAGS, SIZE, BUFFER, REPTR); \ + if (MAP_FAILED != (BUFFER)) (ENTRYPOINT) = (I); \ + } + +# define INTERNAL_XMALLOC_WATERMARK(NAME, WATERMARK, LIMIT, SIZE) { \ + const size_t internal_xmalloc_watermark_ = (WATERMARK) + (SIZE) / 2; /* accept data-race */ \ + if (internal_xmalloc_watermark_ < (LIMIT)) { \ + static size_t internal_xmalloc_watermark_verbose_ = 0; \ + (LIMIT) = internal_xmalloc_watermark_; /* accept data-race */ \ + if (internal_xmalloc_watermark_verbose_ < internal_xmalloc_watermark_ && \ + (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity)) \ + { /* muted */ \ + char internal_xmalloc_watermark_buffer_[32]; \ + /* coverity[check_return] */ \ + libxsmm_format_value(internal_xmalloc_watermark_buffer_, sizeof(internal_xmalloc_watermark_buffer_), \ + internal_xmalloc_watermark_, "KM", "B", 10); \ + fprintf(stderr, "LIBXSMM WARNING: " NAME " watermark reached at %s!\n", internal_xmalloc_watermark_buffer_); \ + internal_xmalloc_watermark_verbose_ = internal_xmalloc_watermark_; \ + } \ + } \ +} + +# define INTERNAL_XMALLOC_KIND(KIND, NAME, FLAG, FLAGS, MFLAGS, WATERMARK, LIMIT, INFO, SIZE, BUFFER) \ + if (0 != ((KIND) & (MFLAGS))) { \ + if (MAP_FAILED != (BUFFER)) { \ + LIBXSMM_ASSERT(NULL != (BUFFER)); \ + LIBXSMM_ATOMIC_ADD_FETCH(&(WATERMARK), SIZE, LIBXSMM_ATOMIC_RELAXED); \ + (FLAGS) |= (FLAG); \ + } \ + else { /* retry */ \ + (BUFFER) = mmap(NULL == (INFO) ? NULL : (INFO)->pointer, SIZE, PROT_READ | PROT_WRITE, \ + MAP_PRIVATE | LIBXSMM_MAP_ANONYMOUS | ((MFLAGS) & ~(KIND)), -1, 0/*offset*/); \ + if (MAP_FAILED != (BUFFER)) { /* successful retry */ \ + LIBXSMM_ASSERT(NULL != (BUFFER)); \ + INTERNAL_XMALLOC_WATERMARK(NAME, WATERMARK, LIMIT, SIZE); \ + } \ + (FLAGS) &= ~(FLAG); \ + } \ + } \ + else (FLAGS) &= ~(FLAG) +#endif + + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE internal_malloc_info_type { + libxsmm_free_function free; + void *pointer, *reloc; + const void* context; +#if defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) + /* real/allocated size */ + size_t size_alloc; +#endif + /* user-requested size */ + size_t size; + int flags; +#if defined(LIBXSMM_VTUNE) + unsigned int code_id; +#endif +#if !defined(LIBXSMM_MALLOC_CRC_OFF) /* hash *must* be the last entry */ + unsigned int hash; +#endif +} internal_malloc_info_type; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_malloc_pool_type { + char pad[LIBXSMM_MALLOC_SCRATCH_PADDING]; + struct { + size_t minsize, counter, incsize; + char *buffer, *head; +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + const void* site; +# if (0 != LIBXSMM_SYNC) + unsigned int tid; +# endif +#endif + } instance; +} internal_malloc_pool_type; + +/* Scratch pool, which supports up to MAX_NSCRATCH allocation sites. */ +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) +/* LIBXSMM_ALIGNED appears to contradict LIBXSMM_APIVAR, and causes multiple defined symbols (if below is seen in multiple translation units) */ +LIBXSMM_APIVAR_DEFINE(char internal_malloc_pool_buffer[(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)*sizeof(internal_malloc_pool_type)+(LIBXSMM_MALLOC_SCRATCH_PADDING)-1]); +#endif +/* Maximum total size of the scratch memory domain. */ +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_scratch_limit); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_scratch_nmallocs); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_private_max); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_private_cur); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_public_max); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_public_cur); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_local_max); +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_local_cur); +LIBXSMM_APIVAR_DEFINE(int internal_malloc_recursive); +/** 0: regular, 1/odd: intercept/scratch, otherwise: all/scratch */ +LIBXSMM_APIVAR_DEFINE(int internal_malloc_kind); +#if defined(LIBXSMM_MALLOC_HOOK) && defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC) +/* Interval of bytes that permit interception (internal_malloc_kind) */ +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_limit[2]); +#endif +#if (0 != LIBXSMM_SYNC) && defined(LIBXSMM_MALLOC_SCRATCH_JOIN) +LIBXSMM_APIVAR_DEFINE(int internal_malloc_join); +#endif +#if !defined(_WIN32) +# if defined(MAP_HUGETLB) && defined(LIBXSMM_MALLOC_HUGE_PAGES) +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_hugetlb); +# endif +# if defined(MAP_LOCKED) && defined(LIBXSMM_MALLOC_LOCK_PAGES) +LIBXSMM_APIVAR_DEFINE(size_t internal_malloc_plocked); +# endif +#endif + + +LIBXSMM_API_INTERN size_t libxsmm_alignment(size_t size, size_t alignment) +{ + size_t result; + if ((LIBXSMM_MALLOC_ALIGNFCT * LIBXSMM_MALLOC_ALIGNMAX) <= size) { + result = libxsmm_lcm(0 == alignment ? (LIBXSMM_ALIGNMENT) : libxsmm_lcm(alignment, LIBXSMM_ALIGNMENT), LIBXSMM_MALLOC_ALIGNMAX); + } + else { /* small-size request */ + if ((LIBXSMM_MALLOC_ALIGNFCT * LIBXSMM_ALIGNMENT) <= size) { + result = (0 == alignment ? (LIBXSMM_ALIGNMENT) : libxsmm_lcm(alignment, LIBXSMM_ALIGNMENT)); + } + else if (0 != alignment) { /* custom alignment */ + result = libxsmm_lcm(alignment, sizeof(void*)); + } + else { /* tiny-size request */ + result = sizeof(void*); + } + } + return result; +} + + +LIBXSMM_API size_t libxsmm_offset(const size_t offset[], const size_t shape[], size_t ndims, size_t* size) +{ + size_t result = 0, size1 = 0; + if (0 != ndims && NULL != shape) { + size_t i; + result = (NULL != offset ? offset[0] : 0); + size1 = shape[0]; + for (i = 1; i < ndims; ++i) { + result += (NULL != offset ? offset[i] : 0) * size1; + size1 *= shape[i]; + } + } + if (NULL != size) *size = size1; + return result; +} + + +LIBXSMM_API_INLINE +LIBXSMM_ATTRIBUTE_NO_SANITIZE(address) +internal_malloc_info_type* internal_malloc_info(const void* memory, int check) +{ + const char *const buffer = (const char*)memory; + internal_malloc_info_type* result = (internal_malloc_info_type*)(NULL != memory + ? (buffer - sizeof(internal_malloc_info_type)) : NULL); +#if defined(LIBXSMM_MALLOC_HOOK_CHECK) + if ((LIBXSMM_MALLOC_HOOK_CHECK) < check) check = (LIBXSMM_MALLOC_HOOK_CHECK); +#endif + if (0 != check && NULL != result) { /* check ownership */ +#if !defined(_WIN32) /* mprotect: pass address rounded down to page/4k alignment */ + if (1 == check || 0 == mprotect((void*)(((uintptr_t)result) & 0xFFFFFFFFFFFFF000), + sizeof(internal_malloc_info_type), PROT_READ | PROT_WRITE) || ENOMEM != errno) +#endif + { + const int flags_rs = LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_SCRATCH; + const int flags_px = LIBXSMM_MALLOC_FLAG_X | LIBXSMM_MALLOC_FLAG_PRIVATE; + const int flags_mx = LIBXSMM_MALLOC_FLAG_X | LIBXSMM_MALLOC_FLAG_MMAP; + const char *const pointer = (const char*)result->pointer; + union { libxsmm_free_fun fun; const void* ptr; } convert; + convert.fun = result->free.function; + if (((flags_mx != (flags_mx & result->flags)) && NULL != result->reloc) + || (0 == (LIBXSMM_MALLOC_FLAG_X & result->flags) ? 0 : (0 != (flags_rs & result->flags))) + || (0 != (LIBXSMM_MALLOC_FLAG_X & result->flags) && NULL != result->context) +#if defined(LIBXSMM_VTUNE) + || (0 == (LIBXSMM_MALLOC_FLAG_X & result->flags) && 0 != result->code_id) +#endif + || (0 != (~LIBXSMM_MALLOC_FLAG_VALID & result->flags)) + || (0 == (LIBXSMM_MALLOC_FLAG_R & result->flags)) + || (pointer == convert.ptr || pointer == result->context || pointer >= buffer || NULL == pointer) +#if defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) + || (result->size_alloc < result->size) +#endif + || (LIBXSMM_MAX(LIBXSMM_MAX(internal_malloc_public_max, internal_malloc_local_max), internal_malloc_private_max) < result->size + && 0 == (flags_px & result->flags)) || (0 == result->size) + || (2 > libxsmm_ninit) /* before checksum calculation */ +#if !defined(LIBXSMM_MALLOC_CRC_OFF) /* last check: checksum over info */ +# if defined(LIBXSMM_MALLOC_CRC_LIGHT) + || result->hash != LIBXSMM_CRC32U(LIBXSMM_BITS)(LIBXSMM_MALLOC_SEED, &result) +# else + || result->hash != libxsmm_crc32(LIBXSMM_MALLOC_SEED, result, + (const char*)&result->hash - (const char*)result) +# endif +#endif + ) { /* mismatch */ + result = NULL; + } + } +#if !defined(_WIN32) + else { /* mismatch */ + result = NULL; + } +#endif + } + return result; +} + + +LIBXSMM_API_INLINE size_t internal_get_scratch_size(const internal_malloc_pool_type* exclude) +{ + size_t result = 0; +#if !defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) || (1 >= (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + LIBXSMM_UNUSED(exclude); +#else + const internal_malloc_pool_type* pool = (const internal_malloc_pool_type*)LIBXSMM_UP2( + (uintptr_t)internal_malloc_pool_buffer, LIBXSMM_MALLOC_SCRATCH_PADDING); +# if (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + const internal_malloc_pool_type *const end = pool + libxsmm_scratch_pools; + LIBXSMM_ASSERT(libxsmm_scratch_pools <= LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS); + for (; pool != end; ++pool) +# endif /*(1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS))*/ + { + if (0 != pool->instance.minsize) { +# if 1 /* memory info is not used */ + if (pool != exclude && (LIBXSMM_MALLOC_INTERNAL_CALLER) != pool->instance.site) { + result += pool->instance.minsize; + } +# else + const internal_malloc_info_type *const info = internal_malloc_info(pool->instance.buffer, 0/*no check*/); + if (NULL != info && pool != exclude && (LIBXSMM_MALLOC_INTERNAL_CALLER) != pool->instance.site) { + result += info->size; + } +# endif + } + else break; /* early exit */ + } +#endif /*defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS))*/ + return result; +} + + +LIBXSMM_API_INLINE internal_malloc_pool_type* internal_scratch_malloc_pool(const void* memory) +{ + internal_malloc_pool_type* result = NULL; + internal_malloc_pool_type* pool = (internal_malloc_pool_type*)LIBXSMM_UP2( + (uintptr_t)internal_malloc_pool_buffer, LIBXSMM_MALLOC_SCRATCH_PADDING); + const char *const buffer = (const char*)memory; +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + const unsigned int npools = libxsmm_scratch_pools; +#else + const unsigned int npools = 1; +#endif + internal_malloc_pool_type *const end = pool + npools; + LIBXSMM_ASSERT(npools <= LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS); + LIBXSMM_ASSERT(NULL != memory); + for (; pool != end; ++pool) { + if (0 != pool->instance.minsize) { + if (0 != /*LIBXSMM_ATOMIC_LOAD(&*/pool->instance.counter/*, LIBXSMM_ATOMIC_SEQ_CST)*/ +#if 1 /* should be implied by non-zero counter */ + && NULL != pool->instance.buffer +#endif + ){/* check if memory belongs to scratch domain or local domain */ +#if 1 + const size_t size = pool->instance.minsize; +#else + const internal_malloc_info_type *const info = internal_malloc_info(pool->instance.buffer, 0/*no check*/); + const size_t size = info->size; +#endif + if (pool->instance.buffer == buffer /* fast path */ || + (pool->instance.buffer < buffer && buffer < (pool->instance.buffer + size))) + { + result = pool; + break; + } + } + } + else break; /* early exit */ + } + return result; +} + + +LIBXSMM_API_INTERN int internal_xfree(const void* /*memory*/, internal_malloc_info_type* /*info*/); + + +LIBXSMM_API_INTERN void internal_scratch_free(const void* /*memory*/, internal_malloc_pool_type* /*pool*/); +LIBXSMM_API_INTERN void internal_scratch_free(const void* memory, internal_malloc_pool_type* pool) +{ +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + const size_t counter = LIBXSMM_ATOMIC_SUB_FETCH(&pool->instance.counter, 1, LIBXSMM_ATOMIC_SEQ_CST); + char *const pool_buffer = pool->instance.buffer; +# if (!defined(NDEBUG) || defined(LIBXSMM_MALLOC_SCRATCH_TRIM_HEAD)) + char *const buffer = (char*)memory; /* non-const */ + LIBXSMM_ASSERT(pool_buffer <= buffer && buffer < pool_buffer + pool->instance.minsize); +# endif + LIBXSMM_ASSERT(pool_buffer <= pool->instance.head); + if (0 == counter) { /* reuse or reallocate scratch domain */ + internal_malloc_info_type *const info = internal_malloc_info(pool_buffer, 0/*no check*/); + const size_t scale_size = (size_t)(1 != libxsmm_scratch_scale ? (libxsmm_scratch_scale * info->size) : info->size); /* hysteresis */ + const size_t size = pool->instance.minsize + pool->instance.incsize; + LIBXSMM_ASSERT(0 == (LIBXSMM_MALLOC_FLAG_X & info->flags)); /* scratch memory is not executable */ + if (size <= scale_size) { /* reuse scratch domain */ + pool->instance.head = pool_buffer; /* reuse scratch domain */ + } + else { /* release buffer */ +# if !defined(NDEBUG) + static int error_once = 0; +# endif + pool->instance.buffer = pool->instance.head = NULL; +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + pool->instance.site = NULL; /* clear affinity */ +# endif +# if !defined(NDEBUG) + if (EXIT_SUCCESS != internal_xfree(pool_buffer, info) /* invalidates info */ + && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: memory deallocation failed!\n"); + } +# else + internal_xfree(pool_buffer, info); /* !libxsmm_free, invalidates info */ +# endif + } + } +# if defined(LIBXSMM_MALLOC_SCRATCH_TRIM_HEAD) /* TODO: document linear/scoped allocator policy */ + else if (buffer < pool->instance.head) { /* reuse scratch domain */ + pool->instance.head = buffer; + } +# else + LIBXSMM_UNUSED(memory); +# endif +#else + LIBXSMM_UNUSED(memory); LIBXSMM_UNUSED(pool); +#endif +} + + +LIBXSMM_API_INTERN void internal_scratch_malloc(void** /*memory*/, size_t /*size*/, size_t /*alignment*/, int /*flags*/, const void* /*caller*/); +LIBXSMM_API_INTERN void internal_scratch_malloc(void** memory, size_t size, size_t alignment, int flags, const void* caller) +{ + LIBXSMM_ASSERT(NULL != memory && 0 == (LIBXSMM_MALLOC_FLAG_X & flags)); + if (0 == (LIBXSMM_MALLOC_FLAG_REALLOC & flags) || NULL == *memory) { + static int error_once = 0; + size_t local_size = 0; +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + if (0 < libxsmm_scratch_pools) { + internal_malloc_pool_type *const pools = (internal_malloc_pool_type*)LIBXSMM_UP2( + (uintptr_t)internal_malloc_pool_buffer, LIBXSMM_MALLOC_SCRATCH_PADDING); + internal_malloc_pool_type *const end = pools + libxsmm_scratch_pools, *pool = pools; + const size_t align_size = libxsmm_alignment(size, alignment), alloc_size = size + align_size - 1; +# if (0 != LIBXSMM_SYNC) + const unsigned int tid = libxsmm_get_tid(); +# endif + unsigned int npools = 1; +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + const void *const site = caller; /* no further attempt in case of NULL */ + internal_malloc_pool_type *pool0 = end; + for (; pool != end; ++pool) { /* counter: memory info is not employed as pools are still manipulated */ + if (NULL != pool->instance.buffer) { + if ((LIBXSMM_MALLOC_INTERNAL_CALLER) != pool->instance.site) ++npools; /* count number of occupied pools */ + if ( /* find matching pool and enter fast path (draw from pool-buffer) */ +# if (0 != LIBXSMM_SYNC) && !defined(LIBXSMM_MALLOC_SCRATCH_JOIN) + (site == pool->instance.site && tid == pool->instance.tid)) +# elif (0 != LIBXSMM_SYNC) + (site == pool->instance.site && (0 != internal_malloc_join || tid == pool->instance.tid))) +# else + (site == pool->instance.site)) +# endif + { + break; + } + } + else { + if (end == pool0) pool0 = pool; /* first available pool*/ + if (0 == pool->instance.minsize) { /* early exit */ + pool = pool0; break; + } + } + } +# endif + LIBXSMM_ASSERT(NULL != pool); + if (end != pool && 0 <= internal_malloc_kind) { + const size_t counter = LIBXSMM_ATOMIC_ADD_FETCH(&pool->instance.counter, (size_t)1, LIBXSMM_ATOMIC_SEQ_CST); + if (NULL != pool->instance.buffer || 1 != counter) { /* attempt to (re-)use existing pool */ + const internal_malloc_info_type *const info = internal_malloc_info(pool->instance.buffer, 1/*check*/); + const size_t pool_size = ((NULL != info && 0 != counter) ? info->size : 0); + const size_t used_size = pool->instance.head - pool->instance.buffer; + const size_t req_size = alloc_size + used_size; + if (req_size <= pool_size) { /* fast path: draw from pool-buffer */ +# if (0 != LIBXSMM_SYNC) && defined(LIBXSMM_MALLOC_SCRATCH_JOIN) + void *const headaddr = &pool->instance.head; + char *const head = (0 == internal_malloc_join + ? (pool->instance.head += alloc_size) + : ((char*)LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_ADD_FETCH, LIBXSMM_BITS)( + (uintptr_t*)headaddr, alloc_size, LIBXSMM_ATOMIC_SEQ_CST))); +# else + char *const head = (char*)(pool->instance.head += alloc_size); +# endif + *memory = LIBXSMM_ALIGN(head - alloc_size, align_size); + } + else { /* fallback to local memory allocation */ + const size_t incsize = req_size - LIBXSMM_MIN(pool_size, req_size); + pool->instance.incsize = LIBXSMM_MAX(pool->instance.incsize, incsize); +# if (0 != LIBXSMM_SYNC) && defined(LIBXSMM_MALLOC_SCRATCH_JOIN) + if (0 == internal_malloc_join) { + --pool->instance.counter; + } + else { + LIBXSMM_ATOMIC_SUB_FETCH(&pool->instance.counter, 1, LIBXSMM_ATOMIC_SEQ_CST); + } +# else + --pool->instance.counter; +# endif + if ( +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + (LIBXSMM_MALLOC_INTERNAL_CALLER) != pool->instance.site && +# endif + 0 == (LIBXSMM_MALLOC_FLAG_PRIVATE & flags)) + { + const size_t watermark = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_ADD_FETCH, LIBXSMM_BITS)( + &internal_malloc_local_cur, alloc_size, LIBXSMM_ATOMIC_RELAXED); + if (internal_malloc_local_max < watermark) internal_malloc_local_max = watermark; /* accept data-race */ + } + else { + const size_t watermark = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_ADD_FETCH, LIBXSMM_BITS)( + &internal_malloc_private_cur, alloc_size, LIBXSMM_ATOMIC_RELAXED); + if (internal_malloc_private_max < watermark) internal_malloc_private_max = watermark; /* accept data-race */ + } + local_size = size; + } + } + else { /* fresh pool */ + const size_t scratch_limit = libxsmm_get_scratch_limit(); + const size_t scratch_size = internal_get_scratch_size(pool); /* exclude current pool */ + const size_t limit_size = (1 < npools ? (scratch_limit - LIBXSMM_MIN(scratch_size, scratch_limit)) : LIBXSMM_SCRATCH_UNLIMITED); + const size_t scale_size = (size_t)(1 != libxsmm_scratch_scale ? (libxsmm_scratch_scale * alloc_size) : alloc_size); /* hysteresis */ + const size_t incsize = (size_t)(libxsmm_scratch_scale * pool->instance.incsize); + const size_t maxsize = LIBXSMM_MAX(scale_size, pool->instance.minsize) + incsize; + const size_t limsize = LIBXSMM_MIN(maxsize, limit_size); + const size_t minsize = limsize; + assert(1 <= libxsmm_scratch_scale); /* !LIBXSMM_ASSERT */ + LIBXSMM_ASSERT(1 == counter); + pool->instance.incsize = 0; /* reset */ + pool->instance.minsize = minsize; +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + pool->instance.site = site; +# if (0 != LIBXSMM_SYNC) + pool->instance.tid = tid; +# endif +# endif + if (alloc_size <= minsize && /* allocate scratch pool */ + EXIT_SUCCESS == libxsmm_xmalloc(memory, minsize, 0/*auto-align*/, + (flags | LIBXSMM_MALLOC_FLAG_SCRATCH) & ~LIBXSMM_MALLOC_FLAG_REALLOC, + NULL/*extra*/, 0/*extra_size*/)) + { + pool->instance.buffer = (char*)*memory; + pool->instance.head = pool->instance.buffer + alloc_size; + *memory = LIBXSMM_ALIGN((char*)*memory, align_size); +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + if ((LIBXSMM_MALLOC_INTERNAL_CALLER) != pool->instance.site) +# endif + { + LIBXSMM_ATOMIC_ADD_FETCH(&internal_malloc_scratch_nmallocs, 1, LIBXSMM_ATOMIC_RELAXED); + } + } + else { /* fallback to local allocation */ + LIBXSMM_ATOMIC_SUB_FETCH(&pool->instance.counter, 1, LIBXSMM_ATOMIC_SEQ_CST); + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (alloc_size <= minsize) { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate scratch memory!\n"); + } + else if ((LIBXSMM_MALLOC_INTERNAL_CALLER) != caller + && (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity)) + { + fprintf(stderr, "LIBXSMM WARNING: scratch memory domain exhausted!\n"); + } + } + local_size = size; + } + } + } + else { /* fallback to local memory allocation */ + local_size = size; + } + } + else { /* fallback to local memory allocation */ + local_size = size; + } + if (0 != local_size) +#else + local_size = size; +#endif /*defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS))*/ + { /* local memory allocation */ + if (EXIT_SUCCESS != libxsmm_xmalloc(memory, local_size, alignment, + flags & ~(LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_REALLOC), NULL/*extra*/, 0/*extra_size*/) + && /* library code is expected to be mute */0 != libxsmm_verbosity + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: scratch memory fallback failed!\n"); + LIBXSMM_ASSERT(NULL == *memory); + } + if ((LIBXSMM_MALLOC_INTERNAL_CALLER) != caller) { + LIBXSMM_ATOMIC_ADD_FETCH(&internal_malloc_scratch_nmallocs, 1, LIBXSMM_ATOMIC_RELAXED); + } + } + } + else { /* reallocate memory */ + const void *const preserve = *memory; +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + internal_malloc_pool_type *const pool = internal_scratch_malloc_pool(preserve); + if (NULL != pool) { + const internal_malloc_info_type *const info = internal_malloc_info(pool->instance.buffer, 0/*no check*/); + void* buffer; + LIBXSMM_ASSERT(pool->instance.buffer <= pool->instance.head && NULL != info); + internal_scratch_malloc(&buffer, size, alignment, + ~LIBXSMM_MALLOC_FLAG_REALLOC & (LIBXSMM_MALLOC_FLAG_SCRATCH | flags), caller); + if (NULL != buffer) { + memcpy(buffer, preserve, LIBXSMM_MIN(size, info->size)); /* TODO: memmove? */ + *memory = buffer; + } + internal_scratch_free(memory, pool); + } + else +#endif + { /* non-pooled (potentially foreign pointer) */ +#if !defined(NDEBUG) + const int status = +#endif + libxsmm_xmalloc(memory, size, alignment/* no need here to determine alignment of given buffer */, + ~LIBXSMM_MALLOC_FLAG_SCRATCH & flags, NULL/*extra*/, 0/*extra_size*/); + assert(EXIT_SUCCESS == status || NULL == *memory); /* !LIBXSMM_ASSERT */ + } + } +} + + +#if defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) +LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_malloc_fntype libxsmm_malloc_fn); + +#if defined(LIBXSMM_MALLOC_HOOK_QKMALLOC) +LIBXSMM_API_INTERN void* internal_memalign_malloc(size_t /*alignment*/, size_t /*size*/); +LIBXSMM_API_INTERN void* internal_memalign_malloc(size_t alignment, size_t size) +{ + LIBXSMM_UNUSED(alignment); + LIBXSMM_ASSERT(NULL != libxsmm_malloc_fn.malloc.dlsym); + return libxsmm_malloc_fn.malloc.ptr(size); +} +#elif defined(LIBXSMM_MALLOC_HOOK_KMP) +LIBXSMM_API_INTERN void* internal_memalign_twiddle(size_t /*alignment*/, size_t /*size*/); +LIBXSMM_API_INTERN void* internal_memalign_twiddle(size_t alignment, size_t size) +{ + LIBXSMM_ASSERT(NULL != libxsmm_malloc_fn.alignmem.dlsym); + return libxsmm_malloc_fn.alignmem.ptr(size, alignment); +} +#endif +#endif /*defined(LIBXSMM_MALLOC_HOOK_DYNAMIC)*/ + + +#if (defined(LIBXSMM_MALLOC_HOOK) && defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC)) || defined(LIBXSMM_MALLOC_ALIGN_ALL) +LIBXSMM_API_INTERN void* internal_memalign_hook(size_t /*alignment*/, size_t /*size*/, const void* /*caller*/); +LIBXSMM_API_INTERN void* internal_memalign_hook(size_t alignment, size_t size, const void* caller) +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, alignment, size, caller); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, alignment, size, caller); +# endif + return result; +} + +LIBXSMM_API void* __wrap_memalign(size_t /*alignment*/, size_t /*size*/); +LIBXSMM_API void* __wrap_memalign(size_t alignment, size_t size) +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, alignment, size, NULL/*caller*/); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, alignment, size, NULL/*caller*/); +# endif + return result; +} + +LIBXSMM_API_INTERN void* internal_malloc_hook(size_t /*size*/, const void* /*caller*/); +LIBXSMM_API_INTERN void* internal_malloc_hook(size_t size, const void* caller) +{ + return internal_memalign_hook(0/*auto-alignment*/, size, caller); +} + +LIBXSMM_API void* __wrap_malloc(size_t /*size*/); +LIBXSMM_API void* __wrap_malloc(size_t size) +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, 0/*auto-alignment*/, size, NULL/*caller*/); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, 0/*auto-alignment*/, size, NULL/*caller*/); +# endif + return result; +} + +#if defined(LIBXSMM_MALLOC_HOOK_CALLOC) +LIBXSMM_API void* __wrap_calloc(size_t /*num*/, size_t /*size*/); +LIBXSMM_API void* __wrap_calloc(size_t num, size_t size) +{ + void* result; + const size_t nbytes = num * size; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, 0/*auto-alignment*/, nbytes, NULL/*caller*/); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, 0/*auto-alignment*/, nbytes, NULL/*caller*/); +# endif + /* TODO: signal anonymous/zeroed pages */ + if (NULL != result) memset(result, 0, nbytes); + return result; +} +#endif + +#if defined(LIBXSMM_MALLOC_HOOK_REALLOC) +LIBXSMM_API_INTERN void* internal_realloc_hook(void* /*ptr*/, size_t /*size*/, const void* /*caller*/); +LIBXSMM_API_INTERN void* internal_realloc_hook(void* ptr, size_t size, const void* caller) +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_REALLOC_HOOK(result, LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_MMAP, ptr, size, caller); +# else + INTERNAL_REALLOC_HOOK(result, LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_DEFAULT, ptr, size, caller); +# endif + return result; +} + +LIBXSMM_API void* __wrap_realloc(void* /*ptr*/, size_t /*size*/); +LIBXSMM_API void* __wrap_realloc(void* ptr, size_t size) +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_REALLOC_HOOK(result, LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_MMAP, ptr, size, NULL/*caller*/); +# else + INTERNAL_REALLOC_HOOK(result, LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_DEFAULT, ptr, size, NULL/*caller*/); +# endif + return result; +} +#endif + +LIBXSMM_API_INTERN void internal_free_hook(void* /*ptr*/, const void* /*caller*/); +LIBXSMM_API_INTERN void internal_free_hook(void* ptr, const void* caller) +{ + INTERNAL_FREE_HOOK(ptr, caller); +} + +LIBXSMM_API void __wrap_free(void* /*ptr*/); +LIBXSMM_API void __wrap_free(void* ptr) +{ + INTERNAL_FREE_HOOK(ptr, NULL/*caller*/); +} +#endif + +#if defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) && ((defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC)) || defined(LIBXSMM_MALLOC_ALIGN_ALL)) +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE_MALLOC void* memalign(size_t /*alignment*/, size_t /*size*/) LIBXSMM_THROW; +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE_MALLOC void* memalign(size_t alignment, size_t size) LIBXSMM_THROW +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, alignment, size, NULL/*caller*/); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, alignment, size, NULL/*caller*/); +# endif + return result; +} + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE_MALLOC void* malloc(size_t /*size*/) LIBXSMM_THROW; +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE_MALLOC void* malloc(size_t size) LIBXSMM_THROW +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, 0/*auto-alignment*/, size, NULL/*caller*/); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, 0/*auto-alignment*/, size, NULL/*caller*/); +# endif + return result; +} + +#if defined(LIBXSMM_MALLOC_HOOK_CALLOC) +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE_MALLOC void* calloc(size_t /*num*/, size_t /*size*/) LIBXSMM_THROW; +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK LIBXSMM_ATTRIBUTE_MALLOC void* calloc(size_t num, size_t size) LIBXSMM_THROW +{ + void* result; + const size_t nbytes = num * size; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_MMAP, 0/*auto-alignment*/, nbytes, NULL/*caller*/); +# else + INTERNAL_MEMALIGN_HOOK(result, LIBXSMM_MALLOC_FLAG_DEFAULT, 0/*auto-alignment*/, nbytes, NULL/*caller*/); +# endif + /* TODO: signal anonymous/zeroed pages */ + if (NULL != result) memset(result, 0, nbytes); + return result; +} +#endif + +#if defined(LIBXSMM_MALLOC_HOOK_REALLOC) +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void* realloc(void* /*ptr*/, size_t /*size*/) LIBXSMM_THROW; +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void* realloc(void* ptr, size_t size) LIBXSMM_THROW +{ + void* result; +# if defined(LIBXSMM_MALLOC_MMAP_HOOK) + INTERNAL_REALLOC_HOOK(result, LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_MMAP, ptr, size, NULL/*caller*/); +# else + INTERNAL_REALLOC_HOOK(result, LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_DEFAULT, ptr, size, NULL/*caller*/); +# endif + return result; +} +#endif + +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void free(void* /*ptr*/) LIBXSMM_THROW; +LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void free(void* ptr) LIBXSMM_THROW +{ + INTERNAL_FREE_HOOK(ptr, NULL/*caller*/); +} +#endif + + +LIBXSMM_API_INTERN int internal_xfree(const void* memory, internal_malloc_info_type* info) +{ +#if !defined(LIBXSMM_BUILD) || !defined(_WIN32) + static int error_once = 0; +#endif + int result = EXIT_SUCCESS; + internal_malloc_info_type local; + LIBXSMM_ASSIGN127(&local, info); +#if !defined(LIBXSMM_BUILD) /* sanity check */ + if (NULL != local.pointer || 0 == local.size) +#endif + { +#if !defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) || !defined(NDEBUG) + const size_t size = local.size + (size_t)(((const char*)memory) - ((const char*)local.pointer)); +#endif +#if defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) + const size_t size_alloc = local.size_alloc; + assert(0 == local.size || (NULL != local.pointer && size <= size_alloc)); /* !LIBXSMM_ASSERT */ +#else + const size_t size_alloc = /*LIBXSMM_UP2(*/size/*, LIBXSMM_PAGE_MINSIZE)*/; +#endif + assert(NULL != memory && NULL != info && sizeof(internal_malloc_info_type) < size_alloc); /* !LIBXSMM_ASSERT */ +#if defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) && defined(NDEBUG) + LIBXSMM_UNUSED(memory); +#endif + if (0 == (LIBXSMM_MALLOC_FLAG_MMAP & local.flags)) { + if (NULL != local.free.function) { +#if defined(LIBXSMM_MALLOC_DELETE_SAFE) + LIBXSMM_MEMZERO127(info); +#endif + if (NULL == local.context) { +#if defined(LIBXSMM_MALLOC_HOOK) + if (free == local.free.function) { + __real_free(local.pointer); + } + else +#endif + if (NULL != local.free.function) { + local.free.function(local.pointer); + } + } + else { + LIBXSMM_ASSERT(NULL != local.free.ctx_form); + local.free.ctx_form(local.pointer, local.context); + } + } + } + else { +#if defined(LIBXSMM_VTUNE) + if (0 != (LIBXSMM_MALLOC_FLAG_X & local.flags) && 0 != local.code_id && iJIT_SAMPLING_ON == iJIT_IsProfilingActive()) { + iJIT_NotifyEvent(LIBXSMM_VTUNE_JIT_UNLOAD, &local.code_id); + } +#endif +#if defined(_WIN32) + result = (NULL == local.pointer || FALSE != VirtualFree(local.pointer, 0, MEM_RELEASE)) ? EXIT_SUCCESS : EXIT_FAILURE; +#else /* !_WIN32 */ + { + if (0 != munmap(local.pointer, size_alloc)) { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: %s (attempted to unmap buffer %p+%" PRIuPTR ")!\n", + strerror(errno), local.pointer, (uintptr_t)size_alloc); + } + result = EXIT_FAILURE; + } + if (0 != (LIBXSMM_MALLOC_FLAG_X & local.flags) && EXIT_SUCCESS == result + && NULL != local.reloc && MAP_FAILED != local.reloc && local.pointer != local.reloc + && 0 != munmap(local.reloc, size_alloc)) + { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: %s (attempted to unmap code %p+%" PRIuPTR ")!\n", + strerror(errno), local.reloc, (uintptr_t)size_alloc); + } + result = EXIT_FAILURE; + } + } +#endif + } + if (0 == (LIBXSMM_MALLOC_FLAG_X & local.flags)) { /* update statistics */ +#if !defined(_WIN32) +# if defined(MAP_HUGETLB) && defined(LIBXSMM_MALLOC_HUGE_PAGES) + if (0 != (LIBXSMM_MALLOC_FLAG_PHUGE & local.flags)) { /* huge pages */ + LIBXSMM_ASSERT(0 != (LIBXSMM_MALLOC_FLAG_MMAP & local.flags)); + LIBXSMM_ATOMIC_SUB_FETCH(&internal_malloc_hugetlb, size_alloc, LIBXSMM_ATOMIC_RELAXED); + } +# endif +# if defined(MAP_LOCKED) && defined(LIBXSMM_MALLOC_LOCK_PAGES) + if (0 != (LIBXSMM_MALLOC_FLAG_PLOCK & local.flags)) { /* page-locked */ + LIBXSMM_ASSERT(0 != (LIBXSMM_MALLOC_FLAG_MMAP & local.flags)); + LIBXSMM_ATOMIC_SUB_FETCH(&internal_malloc_plocked, size_alloc, LIBXSMM_ATOMIC_RELAXED); + } +# endif +#endif + if (0 == (LIBXSMM_MALLOC_FLAG_PRIVATE & local.flags)) { /* public */ + if (0 != (LIBXSMM_MALLOC_FLAG_SCRATCH & local.flags)) { /* scratch */ + const size_t current = (size_t)LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_LOAD, LIBXSMM_BITS)( + &internal_malloc_public_cur, LIBXSMM_ATOMIC_RELAXED); + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, LIBXSMM_BITS)(&internal_malloc_public_cur, + size_alloc <= current ? (current - size_alloc) : 0, LIBXSMM_ATOMIC_RELAXED); + } + else { /* local */ + const size_t current = (size_t)LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_LOAD, LIBXSMM_BITS)( + &internal_malloc_local_cur, LIBXSMM_ATOMIC_RELAXED); + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, LIBXSMM_BITS)(&internal_malloc_local_cur, + size_alloc <= current ? (current - size_alloc) : 0, LIBXSMM_ATOMIC_RELAXED); + } + } + else { /* private */ + const size_t current = (size_t)LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_LOAD, LIBXSMM_BITS)( + &internal_malloc_private_cur, LIBXSMM_ATOMIC_RELAXED); + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_STORE, LIBXSMM_BITS)(&internal_malloc_private_cur, + size_alloc <= current ? (current - size_alloc) : 0, LIBXSMM_ATOMIC_RELAXED); + } + } + } +#if !defined(LIBXSMM_BUILD) + else if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: attempt to release memory from non-matching implementation!\n"); + } +#endif + return result; +} + + +LIBXSMM_API_INTERN void libxsmm_malloc_init(void) +{ +#if (0 != LIBXSMM_SYNC) && defined(LIBXSMM_MALLOC_SCRATCH_JOIN) + const char *const env = getenv("LIBXSMM_MALLOC_JOIN"); + if (NULL != env && '\0' != *env) internal_malloc_join = atoi(env); +#endif +#if defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) +# if defined(LIBXSMM_MALLOC_HOOK_QKMALLOC) + void* handle_qkmalloc = NULL; + dlerror(); /* clear an eventual error status */ + handle_qkmalloc = dlopen("libqkmalloc.so", RTLD_LAZY); + if (NULL != handle_qkmalloc) { + libxsmm_malloc_fn.memalign.ptr = internal_memalign_malloc; + libxsmm_malloc_fn.malloc.dlsym = dlsym(handle_qkmalloc, "malloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.malloc.dlsym) { +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + libxsmm_malloc_fn.calloc.dlsym = dlsym(handle_qkmalloc, "calloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.calloc.dlsym) +# endif + { +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + libxsmm_malloc_fn.realloc.dlsym = dlsym(handle_qkmalloc, "realloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.realloc.dlsym) +# endif + { + libxsmm_malloc_fn.free.dlsym = dlsym(handle_qkmalloc, "free"); + } + } + } + dlclose(handle_qkmalloc); + } + if (NULL == libxsmm_malloc_fn.free.ptr) +# elif defined(LIBXSMM_MALLOC_HOOK_KMP) + dlerror(); /* clear an eventual error status */ + libxsmm_malloc_fn.alignmem.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "kmp_aligned_malloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.alignmem.dlsym) { + libxsmm_malloc_fn.memalign.ptr = internal_memalign_twiddle; + libxsmm_malloc_fn.malloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "kmp_malloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.malloc.dlsym) { +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + libxsmm_malloc_fn.calloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "kmp_calloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.calloc.dlsym) +# endif + { +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + libxsmm_malloc_fn.realloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "kmp_realloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.realloc.dlsym) +# endif + { + libxsmm_malloc_fn.free.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "kmp_free"); + } + } + } + } + if (NULL == libxsmm_malloc_fn.free.ptr) +# endif /*defined(LIBXSMM_MALLOC_HOOK_QKMALLOC)*/ + { + dlerror(); /* clear an eventual error status */ +# if (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) + libxsmm_malloc_fn.memalign.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__libc_memalign"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.memalign.dlsym) { + libxsmm_malloc_fn.malloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__libc_malloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.malloc.dlsym) { +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + libxsmm_malloc_fn.calloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__libc_calloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.calloc.dlsym) +# endif + { +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + libxsmm_malloc_fn.realloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__libc_realloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.realloc.dlsym) +# endif + { + libxsmm_malloc_fn.free.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__libc_free"); + } + } + } + } + if (NULL == libxsmm_malloc_fn.free.ptr) { + void* handle_libc = NULL; + dlerror(); /* clear an eventual error status */ + handle_libc = dlopen("libc.so." LIBXSMM_STRINGIFY(LIBXSMM_MALLOC_GLIBC), RTLD_LAZY); + if (NULL != handle_libc) { + libxsmm_malloc_fn.memalign.dlsym = dlsym(handle_libc, "__libc_memalign"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.memalign.dlsym) { + libxsmm_malloc_fn.malloc.dlsym = dlsym(handle_libc, "__libc_malloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.malloc.dlsym) { +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + libxsmm_malloc_fn.calloc.dlsym = dlsym(handle_libc, "__libc_calloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.calloc.dlsym) +# endif + { +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + libxsmm_malloc_fn.realloc.dlsym = dlsym(handle_libc, "__libc_realloc"); + if (NULL == dlerror() && NULL != libxsmm_malloc_fn.realloc.dlsym) +# endif + { + libxsmm_malloc_fn.free.dlsym = dlsym(handle_libc, "__libc_free"); + } + } + } + } + dlclose(handle_libc); + } + } +# if 0 + { /* attempt to setup deprecated GLIBC hooks */ + union { const void* dlsym; void* (**ptr)(size_t, size_t, const void*); } hook_memalign; + dlerror(); /* clear an eventual error status */ + hook_memalign.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__memalign_hook"); + if (NULL == dlerror() && NULL != hook_memalign.dlsym) { + union { const void* dlsym; void* (**ptr)(size_t, const void*); } hook_malloc; + hook_malloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__malloc_hook"); + if (NULL == dlerror() && NULL != hook_malloc.dlsym) { +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + union { const void* dlsym; void* (**ptr)(void*, size_t, const void*); } hook_realloc; + hook_realloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__realloc_hook"); + if (NULL == dlerror() && NULL != hook_realloc.dlsym) +# endif + { + union { const void* dlsym; void (**ptr)(void*, const void*); } hook_free; + hook_free.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "__free_hook"); + if (NULL == dlerror() && NULL != hook_free.dlsym) { + *hook_memalign.ptr = internal_memalign_hook; + *hook_malloc.ptr = internal_malloc_hook; +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + *hook_realloc.ptr = internal_realloc_hook; +# endif + *hook_free.ptr = internal_free_hook; + } + } + } + } + } +# endif +# else /* TODO */ +# endif /*(defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD)))*/ + } + if (NULL != libxsmm_malloc_fn.free.ptr) { +# if defined(LIBXSMM_MALLOC_HOOK_IMALLOC) + union { const void* dlsym; libxsmm_malloc_fun* ptr; } i_malloc; + i_malloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "i_malloc"); + if (NULL == dlerror() && NULL != i_malloc.dlsym) { +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + union { const void* dlsym; void* (**ptr)(size_t, size_t); } i_calloc; + i_calloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "i_calloc"); + if (NULL == dlerror() && NULL != i_calloc.dlsym) +# endif + { +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + union { const void* dlsym; libxsmm_realloc_fun* ptr; } i_realloc; + i_realloc.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "i_realloc"); + if (NULL == dlerror() && NULL != i_realloc.dlsym) +# endif + { + union { const void* dlsym; libxsmm_free_fun* ptr; } i_free; + i_free.dlsym = dlsym(LIBXSMM_RTLD_NEXT, "i_free"); + if (NULL == dlerror() && NULL != i_free.dlsym) { + *i_malloc.ptr = libxsmm_malloc_fn.malloc.ptr; +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + *i_calloc.ptr = libxsmm_malloc_fn.calloc.ptr; +# endif +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + *i_realloc.ptr = libxsmm_malloc_fn.realloc.ptr; +# endif + *i_free.ptr = libxsmm_malloc_fn.free.ptr; + } + } + } + } +# endif /*defined(LIBXSMM_MALLOC_HOOK_IMALLOC)*/ + } + else { /* fallback: potentially recursive */ +# if (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) + libxsmm_malloc_fn.memalign.ptr = __libc_memalign; + libxsmm_malloc_fn.malloc.ptr = __libc_malloc; +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + libxsmm_malloc_fn.calloc.ptr = __libc_calloc; +# endif +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + libxsmm_malloc_fn.realloc.ptr = __libc_realloc; +# endif + libxsmm_malloc_fn.free.ptr = __libc_free; +# else + libxsmm_malloc_fn.memalign.ptr = libxsmm_memalign_internal; + libxsmm_malloc_fn.malloc.ptr = malloc; +# if defined(LIBXSMM_MALLOC_HOOK_CALLOC) + libxsmm_malloc_fn.calloc.ptr = calloc; +# endif +# if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + libxsmm_malloc_fn.realloc.ptr = realloc; +# endif + libxsmm_malloc_fn.free.ptr = free; +# endif + } +#endif +} + + +LIBXSMM_API_INTERN void libxsmm_malloc_finalize(void) +{ +} + + +LIBXSMM_API_INTERN int libxsmm_xset_default_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn) +{ + int result = EXIT_SUCCESS; + if (NULL != lock) { + LIBXSMM_INIT + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, lock); + } + if (NULL != malloc_fn.function && NULL != free_fn.function) { + libxsmm_default_allocator_context = context; + libxsmm_default_malloc_fn = malloc_fn; + libxsmm_default_free_fn = free_fn; + } + else { + libxsmm_malloc_function internal_malloc_fn; + libxsmm_free_function internal_free_fn; + const void* internal_allocator = NULL; + internal_malloc_fn.function = __real_malloc; + internal_free_fn.function = __real_free; + /*internal_allocator = NULL;*/ + if (NULL == malloc_fn.function && NULL == free_fn.function) { + libxsmm_default_allocator_context = internal_allocator; + libxsmm_default_malloc_fn = internal_malloc_fn; + libxsmm_default_free_fn = internal_free_fn; + } + else { /* invalid allocator */ + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: allocator setup without malloc or free function!\n"); + } + /* keep any valid (previously instantiated) default allocator */ + if (NULL == libxsmm_default_malloc_fn.function || NULL == libxsmm_default_free_fn.function) { + libxsmm_default_allocator_context = internal_allocator; + libxsmm_default_malloc_fn = internal_malloc_fn; + libxsmm_default_free_fn = internal_free_fn; + } + result = EXIT_FAILURE; + } + } + if (NULL != lock) { + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, lock); + } + LIBXSMM_ASSERT(EXIT_SUCCESS == result); + return result; +} + + +LIBXSMM_API_INTERN int libxsmm_xget_default_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void** context, libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn) +{ + int result = EXIT_SUCCESS; + if (NULL != context || NULL != malloc_fn || NULL != free_fn) { + if (NULL != lock) { + LIBXSMM_INIT + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, lock); + } + if (context) *context = libxsmm_default_allocator_context; + if (NULL != malloc_fn) *malloc_fn = libxsmm_default_malloc_fn; + if (NULL != free_fn) *free_fn = libxsmm_default_free_fn; + if (NULL != lock) { + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, lock); + } + } + else if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM ERROR: invalid signature used to get the default memory allocator!\n"); + } + result = EXIT_FAILURE; + } + LIBXSMM_ASSERT(EXIT_SUCCESS == result); + return result; +} + + +LIBXSMM_API_INTERN int libxsmm_xset_scratch_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn) +{ + int result = EXIT_SUCCESS; + static int error_once = 0; + if (NULL != lock) { + LIBXSMM_INIT + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, lock); + } + /* make sure the default allocator is setup before adopting it eventually */ + if (NULL == libxsmm_default_malloc_fn.function || NULL == libxsmm_default_free_fn.function) { + const libxsmm_malloc_function null_malloc_fn = { NULL }; + const libxsmm_free_function null_free_fn = { NULL }; + libxsmm_xset_default_allocator(NULL/*already locked*/, NULL/*context*/, null_malloc_fn, null_free_fn); + } + if (NULL == malloc_fn.function && NULL == free_fn.function) { /* adopt default allocator */ + libxsmm_scratch_allocator_context = libxsmm_default_allocator_context; + libxsmm_scratch_malloc_fn = libxsmm_default_malloc_fn; + libxsmm_scratch_free_fn = libxsmm_default_free_fn; + } + else if (NULL != malloc_fn.function) { + if (NULL == free_fn.function + && /*warning*/(LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: scratch allocator setup without free function!\n"); + } + libxsmm_scratch_allocator_context = context; + libxsmm_scratch_malloc_fn = malloc_fn; + libxsmm_scratch_free_fn = free_fn; /* NULL allowed */ + } + else { /* invalid scratch allocator */ + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid scratch allocator (default used)!\n"); + } + /* keep any valid (previously instantiated) scratch allocator */ + if (NULL == libxsmm_scratch_malloc_fn.function) { + libxsmm_scratch_allocator_context = libxsmm_default_allocator_context; + libxsmm_scratch_malloc_fn = libxsmm_default_malloc_fn; + libxsmm_scratch_free_fn = libxsmm_default_free_fn; + } + result = EXIT_FAILURE; + } + if (NULL != lock) { + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, lock); + } + LIBXSMM_ASSERT(EXIT_SUCCESS == result); + return result; +} + + +LIBXSMM_API_INTERN int libxsmm_xget_scratch_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock, + const void** context, libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn) +{ + int result = EXIT_SUCCESS; + if (NULL != context || NULL != malloc_fn || NULL != free_fn) { + if (NULL != lock) { + LIBXSMM_INIT + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, lock); + } + if (context) *context = libxsmm_scratch_allocator_context; + if (NULL != malloc_fn) *malloc_fn = libxsmm_scratch_malloc_fn; + if (NULL != free_fn) *free_fn = libxsmm_scratch_free_fn; + if (NULL != lock) { + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, lock); + } + } + else if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM ERROR: invalid signature used to get the scratch memory allocator!\n"); + } + result = EXIT_FAILURE; + } + LIBXSMM_ASSERT(EXIT_SUCCESS == result); + return result; +} + + +LIBXSMM_API int libxsmm_set_default_allocator(const void* context, + libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn) +{ + return libxsmm_xset_default_allocator(&libxsmm_lock_global, context, malloc_fn, free_fn); +} + + +LIBXSMM_API int libxsmm_get_default_allocator(const void** context, + libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn) +{ + return libxsmm_xget_default_allocator(&libxsmm_lock_global, context, malloc_fn, free_fn); +} + + +LIBXSMM_API int libxsmm_set_scratch_allocator(const void* context, + libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn) +{ + return libxsmm_xset_scratch_allocator(&libxsmm_lock_global, context, malloc_fn, free_fn); +} + + +LIBXSMM_API int libxsmm_get_scratch_allocator(const void** context, + libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn) +{ + return libxsmm_xget_scratch_allocator(&libxsmm_lock_global, context, malloc_fn, free_fn); +} + + +LIBXSMM_API int libxsmm_get_malloc_xinfo(const void* memory, size_t* size, int* flags, void** extra) +{ + int result; +#if !defined(NDEBUG) + if (NULL != size || NULL != extra) +#endif + { + const int check = ((NULL == flags || 0 == (LIBXSMM_MALLOC_FLAG_X & *flags)) ? 2 : 1); + const internal_malloc_info_type *const info = internal_malloc_info(memory, check); + if (NULL != info) { + if (NULL != size) *size = info->size; + if (NULL != flags) *flags = info->flags; + if (NULL != extra) *extra = info->pointer; + result = EXIT_SUCCESS; + } + else { /* potentially foreign buffer */ + result = (NULL != memory ? EXIT_FAILURE : EXIT_SUCCESS); + if (NULL != size) *size = 0; + if (NULL != flags) *flags = 0; + if (NULL != extra) *extra = NULL; + } + } +#if !defined(NDEBUG) + else { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: attachment error for memory buffer %p!\n", memory); + } + LIBXSMM_ASSERT_MSG(0/*false*/, "LIBXSMM ERROR: attachment error"); + result = EXIT_FAILURE; + } +#endif + return result; +} + + +#if !defined(_WIN32) + +LIBXSMM_API_INLINE void internal_xmalloc_mhint(void* buffer, size_t size) +{ + LIBXSMM_ASSERT((MAP_FAILED != buffer && NULL != buffer) || 0 == size); +#if (defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE)) + /* proceed after failed madvise (even in case of an error; take what we got) */ + /* issue no warning as a failure seems to be related to the kernel version */ + madvise(buffer, size, MADV_NORMAL/*MADV_RANDOM*/ +# if defined(MADV_NOHUGEPAGE) /* if not available, we then take what we got (THP) */ + | ((LIBXSMM_MALLOC_ALIGNMAX * LIBXSMM_MALLOC_ALIGNFCT) > size ? MADV_NOHUGEPAGE : 0) +# endif +# if defined(MADV_DONTDUMP) + | ((LIBXSMM_MALLOC_ALIGNMAX * LIBXSMM_MALLOC_ALIGNFCT) > size ? 0 : MADV_DONTDUMP) +# endif + ); +#else + LIBXSMM_UNUSED(buffer); LIBXSMM_UNUSED(size); +#endif +} + + +LIBXSMM_API_INLINE void* internal_xmalloc_xmap(const char* dir, size_t size, int flags, void** rx) +{ + void* result = MAP_FAILED; + char filename[4096] = LIBXSMM_MALLOC_XMAP_TEMPLATE; + int i = 0; + LIBXSMM_ASSERT(NULL != rx && MAP_FAILED != *rx); + if (NULL != dir && '\0' != *dir) { + i = LIBXSMM_SNPRINTF(filename, sizeof(filename), "%s/" LIBXSMM_MALLOC_XMAP_TEMPLATE, dir); + } + if (0 <= i && i < (int)sizeof(filename)) { + /* coverity[secure_temp] */ + i = mkstemp(filename); + if (0 <= i) { + if (0 == unlink(filename) && 0 == ftruncate(i, size) /*&& 0 == chmod(filename, S_IRWXU)*/) { + const int mflags = (flags | LIBXSMM_MAP_SHARED); + void *const xmap = mmap(*rx, size, PROT_READ | PROT_EXEC, mflags, i, 0/*offset*/); + if (MAP_FAILED != xmap) { + LIBXSMM_ASSERT(NULL != xmap); +#if defined(MAP_32BIT) + result = mmap(NULL, size, PROT_READ | PROT_WRITE, mflags & ~MAP_32BIT, i, 0/*offset*/); +#else + result = mmap(NULL, size, PROT_READ | PROT_WRITE, mflags, i, 0/*offset*/); +#endif + if (MAP_FAILED != result) { + LIBXSMM_ASSERT(NULL != result); + internal_xmalloc_mhint(xmap, size); + *rx = xmap; + } + else { + munmap(xmap, size); + *rx = NULL; + } + } + } + close(i); + } + } + return result; +} + +#endif /*!defined(_WIN32)*/ + + +LIBXSMM_API_INLINE void* internal_xrealloc(void** ptr, internal_malloc_info_type** info, size_t size, + libxsmm_realloc_fun realloc_fn, libxsmm_free_fun free_fn) +{ + char *const base = (char*)(NULL != *info ? (*info)->pointer : *ptr), *result; + LIBXSMM_ASSERT(NULL != *ptr && NULL != free_fn); + /* reallocation may implicitly invalidate info */ + result = (char*)(NULL != realloc_fn ? realloc_fn(base, size) : __real_malloc(size)); + if (result == base) { /* signal no-copy */ + LIBXSMM_ASSERT(NULL != result); + *info = NULL; /* no delete */ + *ptr = NULL; /* no copy */ + } + else if (NULL != result) { /* copy */ + if (NULL != realloc_fn) { + const size_t offset_src = (const char*)*ptr - base; + *ptr = result + offset_src; /* copy */ + *info = NULL; /* no delete */ + } + } +#if !defined(NDEBUG) && 0 + else { /* failed */ + if (NULL != *info) { + internal_xfree(*ptr, *info); /* invalidates info */ + } + else { /* foreign pointer */ + free_fn(*ptr); + } + *info = NULL; /* no delete */ + *ptr = NULL; /* no copy */ + } +#else + LIBXSMM_UNUSED(free_fn); +#endif + return result; +} + + +LIBXSMM_API_INTERN void* internal_xmalloc(void** /*ptr*/, internal_malloc_info_type** /*info*/, size_t /*size*/, + const void* /*context*/, libxsmm_malloc_function /*malloc_fn*/, libxsmm_free_function /*free_fn*/); +LIBXSMM_API_INTERN void* internal_xmalloc(void** ptr, internal_malloc_info_type** info, size_t size, + const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn) +{ + void* result; + LIBXSMM_ASSERT(NULL != ptr && NULL != info && NULL != malloc_fn.function); + if (NULL == *ptr) { + result = (NULL == context + ? malloc_fn.function(size) + : malloc_fn.ctx_form(size, context)); + } + else { /* reallocate */ + if (NULL != free_fn.function /* prefer free_fn since it is part of pointer-info */ + ? (__real_free == free_fn.function || free == free_fn.function) + : (__real_malloc == malloc_fn.function || malloc == malloc_fn.function)) + { +#if defined(LIBXSMM_MALLOC_HOOK_REALLOC) + result = internal_xrealloc(ptr, info, size, __real_realloc, __real_free); +#else + result = internal_xrealloc(ptr, info, size, NULL, __real_free); +#endif + } + else { /* fallback with regular allocation */ + result = (NULL == context + ? malloc_fn.function(size) + : malloc_fn.ctx_form(size, context)); + if (NULL == result) { /* failed */ + if (NULL != *info) { + internal_xfree(*ptr, *info); /* invalidates info */ + } + else { /* foreign pointer */ + (NULL != free_fn.function ? free_fn.function : __real_free)(*ptr); + } + *ptr = NULL; /* safe delete */ + } + } + } + return result; +} + + +LIBXSMM_API int libxsmm_xmalloc(void** memory, size_t size, size_t alignment, + int flags, const void* extra, size_t extra_size) +{ + int result = EXIT_SUCCESS; +#if !defined(NDEBUG) + if (NULL != memory) +#endif + { + static int error_once = 0; + if (0 != size) { + size_t alloc_alignment = 0, alloc_size = 0, max_preserve = 0; + internal_malloc_info_type* info = NULL; + void *buffer = NULL, *reloc = NULL; + /* ATOMIC BEGIN: this region should be atomic/locked */ + const void* context = libxsmm_default_allocator_context; + libxsmm_malloc_function malloc_fn = libxsmm_default_malloc_fn; + libxsmm_free_function free_fn = libxsmm_default_free_fn; + if (0 != (LIBXSMM_MALLOC_FLAG_SCRATCH & flags)) { + context = libxsmm_scratch_allocator_context; + malloc_fn = libxsmm_scratch_malloc_fn; + free_fn = libxsmm_scratch_free_fn; +#if defined(LIBXSMM_MALLOC_MMAP_SCRATCH) + flags |= LIBXSMM_MALLOC_FLAG_MMAP; +#endif + } + if ((0 != (internal_malloc_kind & 1) && 0 < internal_malloc_kind) + || NULL == malloc_fn.function || NULL == free_fn.function) + { + malloc_fn.function = __real_malloc; + free_fn.function = __real_free; + context = NULL; + } + /* ATOMIC END: this region should be atomic */ + flags |= LIBXSMM_MALLOC_FLAG_RW; /* normalize given flags since flags=0 is accepted as well */ + if (0 != (LIBXSMM_MALLOC_FLAG_REALLOC & flags) && NULL != *memory) { + info = internal_malloc_info(*memory, 2/*check*/); + if (NULL != info) { + max_preserve = info->size; + } + else { /* reallocation of unknown allocation */ + flags &= ~LIBXSMM_MALLOC_FLAG_MMAP; + } + } + else *memory = NULL; +#if !defined(LIBXSMM_MALLOC_MMAP) + if (0 == (LIBXSMM_MALLOC_FLAG_X & flags) && 0 == (LIBXSMM_MALLOC_FLAG_MMAP & flags)) { + alloc_alignment = (0 == (LIBXSMM_MALLOC_FLAG_REALLOC & flags) ? libxsmm_alignment(size, alignment) : alignment); + alloc_size = size + extra_size + sizeof(internal_malloc_info_type) + alloc_alignment - 1; + buffer = internal_xmalloc(memory, &info, alloc_size, context, malloc_fn, free_fn); + } + else +#endif + if (NULL == info || size != info->size) { +#if defined(_WIN32) || defined(__CYGWIN__) + const int mflags = (0 != (LIBXSMM_MALLOC_FLAG_X & flags) ? PAGE_EXECUTE_READWRITE : PAGE_READWRITE); + static SIZE_T alloc_alignmax = 0, alloc_pagesize = 0; + if (0 == alloc_alignmax) { /* first/one time */ + SYSTEM_INFO system_info; + GetSystemInfo(&system_info); + alloc_pagesize = system_info.dwPageSize; + alloc_alignmax = GetLargePageMinimum(); + } + if ((LIBXSMM_MALLOC_ALIGNMAX * LIBXSMM_MALLOC_ALIGNFCT) <= size) { /* attempt to use large pages */ + HANDLE process_token; + alloc_alignment = (NULL == info + ? (0 == alignment ? alloc_alignmax : libxsmm_lcm(alignment, alloc_alignmax)) + : libxsmm_lcm(alignment, alloc_alignmax)); + alloc_size = LIBXSMM_UP2(size + extra_size + sizeof(internal_malloc_info_type) + alloc_alignment - 1, alloc_alignmax); + if (TRUE == OpenProcessToken(GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &process_token)) { + TOKEN_PRIVILEGES tp; + if (TRUE == LookupPrivilegeValue(NULL, TEXT("SeLockMemoryPrivilege"), &tp.Privileges[0].Luid)) { + tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; tp.PrivilegeCount = 1; /* enable privilege */ + if (TRUE == AdjustTokenPrivileges(process_token, FALSE, &tp, 0, (PTOKEN_PRIVILEGES)NULL, 0) + && ERROR_SUCCESS == GetLastError()/*may has failed (regardless of TRUE)*/) + { + /* VirtualAlloc cannot be used to reallocate memory */ + buffer = VirtualAlloc(NULL, alloc_size, MEM_RESERVE | MEM_COMMIT | MEM_LARGE_PAGES, mflags); + } + tp.Privileges[0].Attributes = 0; /* disable privilege */ + AdjustTokenPrivileges(process_token, FALSE, &tp, 0, (PTOKEN_PRIVILEGES)NULL, 0); + } + CloseHandle(process_token); + } + } + else { /* small allocation using regular page-size */ + alloc_alignment = (NULL == info ? libxsmm_alignment(size, alignment) : alignment); + alloc_size = LIBXSMM_UP2(size + extra_size + sizeof(internal_malloc_info_type) + alloc_alignment - 1, alloc_pagesize); + } + if (NULL == buffer) { /* small allocation or retry with regular page size */ + /* VirtualAlloc cannot be used to reallocate memory */ + buffer = VirtualAlloc(NULL, alloc_size, MEM_RESERVE | MEM_COMMIT, mflags); + } + if (NULL != buffer) { + flags |= LIBXSMM_MALLOC_FLAG_MMAP; /* select the corresponding deallocation */ + } + else if (0 == (LIBXSMM_MALLOC_FLAG_MMAP & flags)) { /* fallback allocation */ + buffer = internal_xmalloc(memory, &info, alloc_size, context, malloc_fn, free_fn); + } +#else /* !defined(_WIN32) */ +# if defined(MAP_HUGETLB) && defined(LIBXSMM_MALLOC_HUGE_PAGES) + static size_t limit_hugetlb = LIBXSMM_SCRATCH_UNLIMITED; +# endif +# if defined(MAP_LOCKED) && defined(LIBXSMM_MALLOC_LOCK_PAGES) + static size_t limit_plocked = LIBXSMM_SCRATCH_UNLIMITED; +# endif +# if defined(MAP_32BIT) + static int map32 = 1; +# endif + int mflags = 0 +# if defined(MAP_UNINITIALIZED) && 0/*fails with WSL*/ + | MAP_UNINITIALIZED /* unlikely available */ +# endif +# if defined(MAP_NORESERVE) + | (LIBXSMM_MALLOC_ALIGNMAX < size ? 0 : MAP_NORESERVE) +# endif +# if defined(MAP_32BIT) + | ((0 != (LIBXSMM_MALLOC_FLAG_X & flags) && 0 != map32 + && (LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid) + && (LIBXSMM_X86_AVX512 < libxsmm_target_archid || + LIBXSMM_X86_AVX > libxsmm_target_archid)) ? MAP_32BIT : 0) +# endif +# if defined(MAP_HUGETLB) && defined(LIBXSMM_MALLOC_HUGE_PAGES) + | ((0 == (LIBXSMM_MALLOC_FLAG_X & flags) + && ((LIBXSMM_MALLOC_ALIGNMAX * LIBXSMM_MALLOC_ALIGNFCT) <= size || + 0 != (LIBXSMM_MALLOC_FLAG_PHUGE & flags)) + && (internal_malloc_hugetlb + size) < limit_hugetlb) ? MAP_HUGETLB : 0) +# endif +# if defined(MAP_LOCKED) && defined(LIBXSMM_MALLOC_LOCK_PAGES) && 0 == (LIBXSMM_MALLOC_LOCK_PAGES) + | (((0 != (LIBXSMM_MALLOC_FLAG_PLOCK & flags) || 0 == (LIBXSMM_MALLOC_FLAG_X & flags)) + && (internal_malloc_plocked + size) < limit_plocked) ? MAP_LOCKED : 0) +# endif + ; /* mflags */ +# if defined(MAP_POPULATE) + { static int prefault = 0; + if (0 == prefault) { /* prefault only on Linux 3.10.0-327 (and later) to avoid data race in page-fault handler */ + struct utsname osinfo; unsigned int version_major = 3, version_minor = 10, version_update = 0, version_patch = 327; + if (0 <= uname(&osinfo) && 0 == strcmp("Linux", osinfo.sysname) + && 4 == sscanf(osinfo.release, "%u.%u.%u-%u", &version_major, &version_minor, &version_update, &version_patch) + && LIBXSMM_VERSION4(3, 10, 0, 327) > LIBXSMM_VERSION4(version_major, version_minor, version_update, version_patch)) + { + mflags |= MAP_POPULATE; prefault = 1; + } + else prefault = -1; + } + else if (1 == prefault) mflags |= MAP_POPULATE; + } +# endif + /* make allocated size at least a multiple of the smallest page-size to avoid split-pages (unmap!) */ + alloc_alignment = libxsmm_lcm(0 == alignment ? libxsmm_alignment(size, alignment) : alignment, LIBXSMM_PAGE_MINSIZE); + alloc_size = LIBXSMM_UP2(size + extra_size + sizeof(internal_malloc_info_type) + alloc_alignment - 1, alloc_alignment); + if (0 == (LIBXSMM_MALLOC_FLAG_X & flags)) { /* anonymous and non-executable */ +# if defined(MAP_32BIT) + LIBXSMM_ASSERT(0 == (MAP_32BIT & mflags)); +# endif +# if 0 + LIBXSMM_ASSERT(NULL != info || NULL == *memory); /* no memory mapping of foreign pointer */ +# endif + buffer = mmap(NULL == info ? NULL : info->pointer, alloc_size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | LIBXSMM_MAP_ANONYMOUS | mflags, -1, 0/*offset*/); +# if defined(MAP_HUGETLB) && defined(LIBXSMM_MALLOC_HUGE_PAGES) + INTERNAL_XMALLOC_KIND(MAP_HUGETLB, "huge-page", LIBXSMM_MALLOC_FLAG_PHUGE, flags, mflags, + internal_malloc_hugetlb, limit_hugetlb, info, alloc_size, buffer); +# endif +# if defined(MAP_LOCKED) && defined(LIBXSMM_MALLOC_LOCK_PAGES) +# if 0 == (LIBXSMM_MALLOC_LOCK_PAGES) + INTERNAL_XMALLOC_KIND(MAP_LOCKED, "locked-page", LIBXSMM_MALLOC_FLAG_PLOCK, flags, mflags, + internal_malloc_plocked, limit_plocked, info, alloc_size, buffer); +# else + if (0 != (MAP_LOCKED & mflags) && MAP_FAILED != buffer) { + LIBXSMM_ASSERT(NULL != buffer); +# if 1 == (LIBXSMM_MALLOC_LOCK_PAGES) || !defined(MLOCK_ONFAULT) || !defined(SYS_mlock2) + if (0 == mlock(buffer, alloc_size)) +# elif 0 /* mlock2 is potentially not exposed */ + if (0 == mlock2(buffer, alloc_size, MLOCK_ONFAULT)) +# else + if (0 == syscall(SYS_mlock2, buffer, alloc_size, MLOCK_ONFAULT)) +# endif + { + LIBXSMM_ATOMIC_ADD_FETCH(&internal_malloc_plocked, alloc_size, LIBXSMM_ATOMIC_RELAXED); + flags |= LIBXSMM_MALLOC_FLAG_PLOCK; + } + else { /* update watermark */ + INTERNAL_XMALLOC_WATERMARK("locked-page", internal_malloc_plocked, limit_plocked, alloc_size); + flags &= ~LIBXSMM_MALLOC_FLAG_PLOCK; + } + } +# endif +# endif + } + else { /* executable buffer requested */ + static /*LIBXSMM_TLS*/ int entrypoint = -1; /* fallback allocation method */ +# if defined(MAP_HUGETLB) && defined(LIBXSMM_MALLOC_HUGE_PAGES) + LIBXSMM_ASSERT(0 == (MAP_HUGETLB & mflags)); +# endif +# if defined(MAP_LOCKED) && defined(LIBXSMM_MALLOC_LOCK_PAGES) + LIBXSMM_ASSERT(0 == (MAP_LOCKED & mflags)); +# endif + if (0 > (int)LIBXSMM_ATOMIC_LOAD(&entrypoint, LIBXSMM_ATOMIC_RELAXED)) { + const char *const env = getenv("LIBXSMM_SE"); + LIBXSMM_ATOMIC_STORE(&entrypoint, NULL == env + /* libxsmm_se decides */ + ? (0 == libxsmm_se ? LIBXSMM_MALLOC_FINAL : LIBXSMM_MALLOC_FALLBACK) + /* user's choice takes precedence */ + : ('0' != *env ? LIBXSMM_MALLOC_FALLBACK : LIBXSMM_MALLOC_FINAL), + LIBXSMM_ATOMIC_SEQ_CST); + LIBXSMM_ASSERT(0 <= entrypoint); + } + INTERNAL_XMALLOC(0, entrypoint, "JITDUMPDIR", "", map32, mflags, alloc_size, buffer, &reloc); /* 1st try */ + INTERNAL_XMALLOC(1, entrypoint, "TMPDIR", "/tmp", map32, mflags, alloc_size, buffer, &reloc); /* 2nd try */ + /* coverity[string_size] */ + INTERNAL_XMALLOC(2, entrypoint, "HOME", "", map32, mflags, alloc_size, buffer, &reloc); /* 3rd try */ + if (3 >= entrypoint && (MAP_FAILED == buffer || NULL == buffer)) { /* 4th try */ + buffer = mmap(reloc, alloc_size, PROT_READ | PROT_WRITE | PROT_EXEC, +# if defined(MAP_32BIT) + MAP_PRIVATE | LIBXSMM_MAP_ANONYMOUS | (0 == map32 ? (mflags & ~MAP_32BIT) : mflags), +# else + MAP_PRIVATE | LIBXSMM_MAP_ANONYMOUS | mflags, +# endif + -1, 0/*offset*/); + if (MAP_FAILED != buffer) entrypoint = 3; +# if defined(MAP_32BIT) + else if (0 != (MAP_32BIT & mflags) && 0 != map32) { + buffer = mmap(reloc, alloc_size, PROT_READ | PROT_WRITE | PROT_EXEC, + MAP_PRIVATE | LIBXSMM_MAP_ANONYMOUS | (mflags & ~MAP_32BIT), + - 1, 0/*offset*/); + if (MAP_FAILED != buffer) { + entrypoint = 3; + map32 = 0; + } + } +# endif + } + /* upgrade to SE-mode and retry lower entry-points */ + if (MAP_FAILED == buffer && 0 == libxsmm_se) { + libxsmm_se = 1; entrypoint = 0; + INTERNAL_XMALLOC(0, entrypoint, "JITDUMPDIR", "", map32, mflags, alloc_size, buffer, &reloc); /* 1st try */ + INTERNAL_XMALLOC(1, entrypoint, "TMPDIR", "/tmp", map32, mflags, alloc_size, buffer, &reloc); /* 2nd try */ + INTERNAL_XMALLOC(2, entrypoint, "HOME", "", map32, mflags, alloc_size, buffer, &reloc); /* 3rd try */ + } + } + if (MAP_FAILED != buffer && NULL != buffer) { + flags |= LIBXSMM_MALLOC_FLAG_MMAP; /* select deallocation */ + } + else { /* allocation failed */ + if (0 == (LIBXSMM_MALLOC_FLAG_MMAP & flags)) { /* ultimate fallback */ + buffer = (NULL != malloc_fn.function + ? (NULL == context ? malloc_fn.function(alloc_size) : malloc_fn.ctx_form(alloc_size, context)) + : (NULL)); + } + reloc = NULL; + } + if (MAP_FAILED != buffer && NULL != buffer) { + internal_xmalloc_mhint(buffer, alloc_size); + } +#endif /* !defined(_WIN32) */ + } + else { /* reallocation of the same pointer and size */ + alloc_size = size + extra_size + sizeof(internal_malloc_info_type) + alignment - 1; + if (NULL != info) { + buffer = info->pointer; + flags |= info->flags; + } + else { + flags |= LIBXSMM_MALLOC_FLAG_MMAP; + buffer = *memory; + } + alloc_alignment = alignment; + *memory = NULL; /* signal no-copy */ + } + if ( +#if !defined(_WIN32) && !defined(__clang_analyzer__) + MAP_FAILED != buffer && +#endif + NULL != buffer) + { + char *const cbuffer = (char*)buffer, *const aligned = LIBXSMM_ALIGN( + cbuffer + extra_size + sizeof(internal_malloc_info_type), alloc_alignment); + internal_malloc_info_type *const buffer_info = (internal_malloc_info_type*)( + aligned - sizeof(internal_malloc_info_type)); + LIBXSMM_ASSERT((aligned + size) <= (cbuffer + alloc_size)); + LIBXSMM_ASSERT(0 < alloc_alignment); + /* former content must be preserved prior to setup of buffer_info */ + if (NULL != *memory) { /* preserve/copy previous content */ +#if 0 + LIBXSMM_ASSERT(0 != (LIBXSMM_MALLOC_FLAG_REALLOC & flags)); +#endif + /* content behind foreign pointers is not explicitly preserved; buffers may overlap */ + memmove(aligned, *memory, LIBXSMM_MIN(max_preserve, size)); + if (NULL != info /* known allocation (non-foreign pointer) */ + && EXIT_SUCCESS != internal_xfree(*memory, info) /* !libxsmm_free, invalidates info */ + && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { /* display some extra context of the failure (reallocation) */ + fprintf(stderr, "LIBXSMM ERROR: memory reallocation failed to release memory!\n"); + } + } + if (NULL != extra || 0 == extra_size) { + const char *const src = (const char*)extra; + int i; for (i = 0; i < (int)extra_size; ++i) cbuffer[i] = src[i]; + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: incorrect extraneous data specification!\n"); + /* no EXIT_FAILURE because valid buffer is returned */ + } + if (0 == (LIBXSMM_MALLOC_FLAG_X & flags)) { /* update statistics */ + if (0 == (LIBXSMM_MALLOC_FLAG_PRIVATE & flags)) { /* public */ + if (0 != (LIBXSMM_MALLOC_FLAG_SCRATCH & flags)) { /* scratch */ + const size_t watermark = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_ADD_FETCH, LIBXSMM_BITS)( + &internal_malloc_public_cur, alloc_size, LIBXSMM_ATOMIC_RELAXED); + if (internal_malloc_public_max < watermark) internal_malloc_public_max = watermark; /* accept data-race */ + } + else { /* local */ + const size_t watermark = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_ADD_FETCH, LIBXSMM_BITS)( + &internal_malloc_local_cur, alloc_size, LIBXSMM_ATOMIC_RELAXED); + if (internal_malloc_local_max < watermark) internal_malloc_local_max = watermark; /* accept data-race */ + } + } + else if (0 != (LIBXSMM_MALLOC_FLAG_SCRATCH & flags)) { /* private scratch */ + const size_t watermark = LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_ADD_FETCH, LIBXSMM_BITS)( + &internal_malloc_private_cur, alloc_size, LIBXSMM_ATOMIC_RELAXED); + if (internal_malloc_private_max < watermark) internal_malloc_private_max = watermark; /* accept data-race */ + } + } + /* keep allocation function on record */ + if (0 == (LIBXSMM_MALLOC_FLAG_MMAP & flags)) { + buffer_info->context = context; + buffer_info->free = free_fn; + } + else { + buffer_info->free.function = NULL; + buffer_info->context = NULL; + } +#if defined(LIBXSMM_MALLOC_INFO_ALLOCSIZE) + buffer_info->size_alloc = alloc_size; +#endif + buffer_info->size = size; + buffer_info->pointer = buffer; + buffer_info->reloc = reloc; + buffer_info->flags = flags; +#if defined(LIBXSMM_VTUNE) + buffer_info->code_id = 0; +#endif /* info must be initialized to calculate correct checksum */ +#if !defined(LIBXSMM_MALLOC_CRC_OFF) +# if defined(LIBXSMM_MALLOC_CRC_LIGHT) + buffer_info->hash = LIBXSMM_CRC32U(LIBXSMM_BITS)(LIBXSMM_MALLOC_SEED, &buffer_info); +# else + buffer_info->hash = libxsmm_crc32(LIBXSMM_MALLOC_SEED, buffer_info, + (unsigned int)(((char*)&buffer_info->hash) - ((char*)buffer_info))); +# endif +#endif /* finally commit/return allocated buffer */ + *memory = aligned; + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + char alloc_size_buffer[32]; + libxsmm_format_value(alloc_size_buffer, sizeof(alloc_size_buffer), alloc_size, "KM", "B", 10); + fprintf(stderr, "LIBXSMM ERROR: failed to allocate %s with flag=%i!\n", alloc_size_buffer, flags); + } + result = EXIT_FAILURE; + *memory = NULL; + } + } + else { + if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: zero-sized memory allocation detected!\n"); + } + *memory = NULL; /* no EXIT_FAILURE */ + } + } +#if !defined(NDEBUG) + else if (0 != size) { + result = EXIT_FAILURE; + } +#endif + return result; +} + + +LIBXSMM_API void libxsmm_xfree(const void* memory, int check) +{ +#if (!defined(LIBXSMM_MALLOC_HOOK) || defined(_DEBUG)) + static int error_once = 0; +#endif + /*const*/ internal_malloc_info_type *const info = internal_malloc_info(memory, check); + if (NULL != info) { /* !libxsmm_free */ +#if (!defined(LIBXSMM_MALLOC_HOOK) || defined(_DEBUG)) + if (EXIT_SUCCESS != internal_xfree(memory, info)) { /* invalidates info */ + if ( 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: memory deallocation failed!\n"); + } + } +#else + internal_xfree(memory, info); /* invalidates info */ +#endif + } + else if (NULL != memory) { +#if 1 + union { const void* const_ptr; void* ptr; } cast; + cast.const_ptr = memory; /* C-cast still warns */ + __real_free(cast.ptr); +#endif +#if (!defined(LIBXSMM_MALLOC_HOOK) || defined(_DEBUG)) + if ( 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: deallocation does not match allocation!\n"); + } +#endif + } +} + + +#if defined(LIBXSMM_VTUNE) +LIBXSMM_API_INLINE void internal_get_vtune_jitdesc(const void* code, + unsigned int code_id, size_t code_size, const char* code_name, + LIBXSMM_VTUNE_JIT_DESC_TYPE* desc) +{ + LIBXSMM_ASSERT(NULL != code && 0 != code_id && 0 != code_size && NULL != desc); + desc->method_id = code_id; + /* incorrect constness (method_name) */ + desc->method_name = (char*)code_name; + /* incorrect constness (method_load_address) */ + desc->method_load_address = (void*)code; + desc->method_size = code_size; + desc->line_number_size = 0; + desc->line_number_table = NULL; + desc->class_file_name = NULL; + desc->source_file_name = NULL; +# if (2 <= LIBXSMM_VTUNE_JITVERSION) + desc->module_name = "libxsmm.jit"; +# endif +} +#endif + + +LIBXSMM_API_INTERN int libxsmm_malloc_attrib(void** memory, int flags, const char* name) +{ + internal_malloc_info_type *const info = (NULL != memory ? internal_malloc_info(*memory, 0/*no check*/) : NULL); + int result = EXIT_SUCCESS; + static int error_once = 0; + if (NULL != info) { + void *const buffer = info->pointer; + const size_t size = info->size; +#if defined(_WIN32) + LIBXSMM_ASSERT(NULL != buffer || 0 == size); +#else + LIBXSMM_ASSERT((NULL != buffer && MAP_FAILED != buffer) || 0 == size); +#endif + flags |= (info->flags & ~LIBXSMM_MALLOC_FLAG_RWX); /* merge with current flags */ + /* quietly keep the read permission, but eventually revoke write permissions */ + if (0 == (LIBXSMM_MALLOC_FLAG_W & flags) || 0 != (LIBXSMM_MALLOC_FLAG_X & flags)) { + const size_t alignment = (size_t)(((const char*)(*memory)) - ((const char*)buffer)); + const size_t alloc_size = size + alignment; + if (0 == (LIBXSMM_MALLOC_FLAG_X & flags)) { /* data-buffer; non-executable */ +#if defined(_WIN32) + /* TODO: implement memory protection under Microsoft Windows */ + LIBXSMM_UNUSED(alloc_size); +#else + if (EXIT_SUCCESS != mprotect(buffer, alloc_size/*entire memory region*/, PROT_READ) + && (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: read-only request for buffer failed!\n"); + } +#endif + } + else { /* executable buffer requested */ + void *const code_ptr = (NULL != info->reloc ? ((void*)(((char*)info->reloc) + alignment)) : *memory); + LIBXSMM_ASSERT(0 != (LIBXSMM_MALLOC_FLAG_X & flags)); + if (name && *name) { /* profiler support requested */ + if (0 > libxsmm_verbosity) { /* avoid dump if just the profiler is enabled */ + LIBXSMM_EXPECT(EXIT_SUCCESS, libxsmm_dump("LIBXSMM-JIT-DUMP", name, code_ptr, size, 1/*unique*/)); + } +#if defined(LIBXSMM_VTUNE) + if (iJIT_SAMPLING_ON == iJIT_IsProfilingActive()) { + LIBXSMM_VTUNE_JIT_DESC_TYPE vtune_jit_desc; + const unsigned int code_id = iJIT_GetNewMethodID(); + internal_get_vtune_jitdesc(code_ptr, code_id, size, name, &vtune_jit_desc); + iJIT_NotifyEvent(LIBXSMM_VTUNE_JIT_LOAD, &vtune_jit_desc); + info->code_id = code_id; + } + else { + info->code_id = 0; + } +#endif +#if defined(LIBXSMM_PERF) + /* If JIT is enabled and a valid name is given, emit information for profiler + * In jitdump case this needs to be done after mprotect as it gets overwritten + * otherwise. */ + libxsmm_perf_dump_code(code_ptr, size, name); +#endif + } + if (NULL != info->reloc && info->pointer != info->reloc) { +#if defined(_WIN32) + /* TODO: implement memory protection under Microsoft Windows */ +#else + /* memory is already protected at this point; relocate code */ + LIBXSMM_ASSERT(0 != (LIBXSMM_MALLOC_FLAG_MMAP & flags)); + *memory = code_ptr; /* relocate */ + info->pointer = info->reloc; + info->reloc = NULL; +# if !defined(LIBXSMM_MALLOC_CRC_OFF) /* update checksum */ +# if defined(LIBXSMM_MALLOC_CRC_LIGHT) + { const internal_malloc_info_type *const code_info = internal_malloc_info(code_ptr, 0/*no check*/); + info->hash = LIBXSMM_CRC32U(LIBXSMM_BITS)(LIBXSMM_MALLOC_SEED, &code_info); + } +# else + info->hash = libxsmm_crc32(LIBXSMM_MALLOC_SEED, info, + /* info size minus actual hash value */ + (unsigned int)(((char*)&info->hash) - ((char*)info))); +# endif +# endif /* treat memory protection errors as soft error; ignore return value */ + munmap(buffer, alloc_size); +#endif + } +#if !defined(_WIN32) + else { /* malloc-based fallback */ + int mprotect_result; +# if !defined(LIBXSMM_MALLOC_CRC_OFF) && defined(LIBXSMM_VTUNE) /* check checksum */ +# if defined(LIBXSMM_MALLOC_CRC_LIGHT) + assert(info->hash == LIBXSMM_CRC32U(LIBXSMM_BITS)(LIBXSMM_MALLOC_SEED, &info)); /* !LIBXSMM_ASSERT */ +# else + assert(info->hash == libxsmm_crc32(LIBXSMM_MALLOC_SEED, info, /* !LIBXSMM_ASSERT */ + /* info size minus actual hash value */ + (unsigned int)(((char*)&info->hash) - ((char*)info)))); +# endif +# endif /* treat memory protection errors as soft error; ignore return value */ + mprotect_result = mprotect(buffer, alloc_size/*entire memory region*/, PROT_READ | PROT_EXEC); + if (EXIT_SUCCESS != mprotect_result) { + if (0 != libxsmm_se) { /* hard-error in case of SELinux */ + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate an executable buffer!\n"); + } + result = mprotect_result; + } + else if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: read-only request for JIT-buffer failed!\n"); + } + } + } +#endif + } + } + } + else if (NULL == memory || NULL == *memory) { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: libxsmm_malloc_attrib failed because NULL cannot be attributed!\n"); + } + result = EXIT_FAILURE; + } + else if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM WARNING: %s buffer %p does not match!\n", + 0 != (LIBXSMM_MALLOC_FLAG_X & flags) ? "executable" : "memory", *memory); + } + return result; +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_MALLOC void* libxsmm_aligned_malloc(size_t size, size_t alignment) +{ + void* result = NULL; + LIBXSMM_INIT + if (2 > internal_malloc_kind) { +#if !defined(NDEBUG) + int status = +#endif + libxsmm_xmalloc(&result, size, alignment, LIBXSMM_MALLOC_FLAG_DEFAULT, NULL/*extra*/, 0/*extra_size*/); + assert(EXIT_SUCCESS == status || NULL == result); /* !LIBXSMM_ASSERT */ + } + else { /* scratch */ + const void *const caller = libxsmm_trace_caller_id(0/*level*/); + internal_scratch_malloc(&result, size, alignment, LIBXSMM_MALLOC_FLAG_DEFAULT, caller); + } + return result; +} + + +LIBXSMM_API void* libxsmm_realloc(size_t size, void* ptr) +{ + const int nzeros = LIBXSMM_INTRINSICS_BITSCANFWD64((uintptr_t)ptr), alignment = 1 << nzeros; + LIBXSMM_ASSERT(0 == ((uintptr_t)ptr & ~(0xFFFFFFFFFFFFFFFF << nzeros))); + LIBXSMM_INIT + if (2 > internal_malloc_kind) { +#if !defined(NDEBUG) + int status = +#endif + libxsmm_xmalloc(&ptr, size, alignment, LIBXSMM_MALLOC_FLAG_REALLOC, NULL/*extra*/, 0/*extra_size*/); + assert(EXIT_SUCCESS == status || NULL == ptr); /* !LIBXSMM_ASSERT */ + } + else { /* scratch */ + const void *const caller = libxsmm_trace_caller_id(0/*level*/); + internal_scratch_malloc(&ptr, size, alignment, LIBXSMM_MALLOC_FLAG_REALLOC, caller); + } + return ptr; +} + + +LIBXSMM_API void* libxsmm_scratch_malloc(size_t size, size_t alignment, const void* caller) +{ + void* result; + LIBXSMM_INIT + internal_scratch_malloc(&result, size, alignment, + LIBXSMM_MALLOC_INTERNAL_CALLER != caller ? LIBXSMM_MALLOC_FLAG_DEFAULT : LIBXSMM_MALLOC_FLAG_PRIVATE, + caller); + return result; +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_MALLOC void* libxsmm_malloc(size_t size) +{ + return libxsmm_aligned_malloc(size, 0/*auto*/); +} + + +LIBXSMM_API void libxsmm_free(const void* memory) +{ + if (NULL != memory) { +#if defined(LIBXSMM_MALLOC_SCRATCH_DELETE_FIRST) || /* prefer safe method if possible */ \ + !defined(LIBXSMM_MALLOC_HOOK) +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + internal_malloc_pool_type *const pool = internal_scratch_malloc_pool(memory); + if (NULL != pool) { /* memory belongs to scratch domain */ + internal_scratch_free(memory, pool); + } + else +# endif + { /* local */ + libxsmm_xfree(memory, 2/*check*/); + } +#else /* lookup matching pool */ + internal_malloc_info_type *const info = internal_malloc_info(memory, 2/*check*/); + static int error_once = 0; + if (NULL != info && 0 == (LIBXSMM_MALLOC_FLAG_SCRATCH & info->flags)) { /* !libxsmm_free */ +# if !defined(NDEBUG) + if (EXIT_SUCCESS != internal_xfree(memory, info) /* invalidates info */ + && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: memory deallocation failed!\n"); + } +# else + internal_xfree(memory, info); /* !libxsmm_free, invalidates info */ +# endif + } + else { +# if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + internal_malloc_pool_type *const pool = internal_scratch_malloc_pool(memory); + if (NULL != pool) { /* memory belongs to scratch domain */ + internal_scratch_free(memory, pool); + } + else +# endif + { +# if defined(NDEBUG) && defined(LIBXSMM_MALLOC_HOOK) + __real_free((void*)memory); +# else +# if defined(LIBXSMM_MALLOC_HOOK) + __real_free((void*)memory); +# endif + if (0 != libxsmm_verbosity && /* library code is expected to be mute */ + 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: deallocation does not match allocation!\n"); + } +# endif + } + } +#endif + } +} + + +LIBXSMM_API_INTERN void libxsmm_xrelease_scratch(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock) +{ +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + internal_malloc_pool_type* pools = NULL; + libxsmm_scratch_info scratch_info; + LIBXSMM_ASSERT(libxsmm_scratch_pools <= LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS); + if (NULL != lock) { + LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK, lock); + } +# if defined(LIBXSMM_MALLOC_DELETE_SAFE) + if (0 == (internal_malloc_kind & 1) || 0 >= internal_malloc_kind) +# endif + { + unsigned int i; + pools = (internal_malloc_pool_type*)LIBXSMM_UP2( + (uintptr_t)internal_malloc_pool_buffer, LIBXSMM_MALLOC_SCRATCH_PADDING); + for (i = 0; i < libxsmm_scratch_pools; ++i) { + if (0 != pools[i].instance.minsize) { + if ( +# if !defined(LIBXSMM_MALLOC_SCRATCH_DELETE_FIRST) + 1 < /*LIBXSMM_ATOMIC_LOAD(&*/pools[i].instance.counter/*, LIBXSMM_ATOMIC_SEQ_CST)*/ && +# endif + NULL != pools[i].instance.buffer) + { + internal_malloc_info_type *const info = internal_malloc_info(pools[i].instance.buffer, 2/*check*/); + if (NULL != info) internal_xfree(info->pointer, info); /* invalidates info */ + } + } + else break; /* early exit */ + } + } + LIBXSMM_EXPECT(EXIT_SUCCESS, libxsmm_get_scratch_info(&scratch_info)); + if (0 != scratch_info.npending && /* library code is expected to be mute */ + (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity)) + { + char pending_size_buffer[32]; + libxsmm_format_value(pending_size_buffer, sizeof(pending_size_buffer), + internal_malloc_public_cur + internal_malloc_local_cur, "KM", "B", 10); + fprintf(stderr, "LIBXSMM WARNING: %s pending scratch-memory by %" PRIuPTR " allocation%s!\n", + pending_size_buffer, (uintptr_t)scratch_info.npending, 1 < scratch_info.npending ? "s" : ""); + } + if (NULL != pools) { + memset(pools, 0, (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) * sizeof(internal_malloc_pool_type)); + /* no reset: keep private watermark (internal_malloc_private_max, internal_malloc_private_cur) */ + internal_malloc_public_max = internal_malloc_public_cur = 0; + internal_malloc_local_max = internal_malloc_local_cur = 0; + internal_malloc_scratch_nmallocs = 0; + } + if (NULL != lock) { + LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK, lock); + } +#endif +} + + +LIBXSMM_API void libxsmm_release_scratch(void) +{ + libxsmm_xrelease_scratch(&libxsmm_lock_global); +} + + +LIBXSMM_API int libxsmm_get_malloc_info(const void* memory, libxsmm_malloc_info* info) +{ + int result = EXIT_SUCCESS; + if (NULL != info) { + size_t size; + result = libxsmm_get_malloc_xinfo(memory, &size, NULL/*flags*/, NULL/*extra*/); + LIBXSMM_MEMZERO127(info); + if (EXIT_SUCCESS == result) { + info->size = size; + } +#if !defined(NDEBUG) /* library code is expected to be mute */ + else if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) { + static int error_once = 0; + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { + fprintf(stderr, "LIBXSMM WARNING: foreign memory buffer %p discovered!\n", memory); + } + } +#endif + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API int libxsmm_get_scratch_info(libxsmm_scratch_info* info) +{ + int result = EXIT_SUCCESS; + if (NULL != info) { +#if defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + LIBXSMM_MEMZERO127(info); + info->nmallocs = internal_malloc_scratch_nmallocs; + info->internal = internal_malloc_private_max; + info->local = internal_malloc_local_max; + info->size = internal_malloc_public_max; + { const internal_malloc_pool_type* pool = (const internal_malloc_pool_type*)LIBXSMM_UP2( + (uintptr_t)internal_malloc_pool_buffer, LIBXSMM_MALLOC_SCRATCH_PADDING); +# if (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + const internal_malloc_pool_type *const end = pool + libxsmm_scratch_pools; + LIBXSMM_ASSERT(libxsmm_scratch_pools <= LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS); + for (; pool != end; ++pool) if ((LIBXSMM_MALLOC_INTERNAL_CALLER) != pool->instance.site) { +# endif + if (0 != pool->instance.minsize) { + const size_t npending = /*LIBXSMM_ATOMIC_LOAD(&*/pool->instance.counter/*, LIBXSMM_ATOMIC_RELAXED)*/; +# if defined(LIBXSMM_MALLOC_SCRATCH_DELETE_FIRST) + info->npending += npending; +# else + info->npending += 1 < npending ? (npending - 1) : 0; +# endif + ++info->npools; + } +# if (1 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS)) + else break; /* early exit */ + } +# endif + } +#else + LIBXSMM_MEMZERO127(info); +#endif /*defined(LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS) && (0 < (LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS))*/ + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API void libxsmm_set_scratch_limit(size_t nbytes) +{ + /* !LIBXSMM_INIT */ + internal_malloc_scratch_limit = nbytes; +} + + +LIBXSMM_API size_t libxsmm_get_scratch_limit(void) +{ + size_t result; + /* !LIBXSMM_INIT */ + if (LIBXSMM_SCRATCH_DEFAULT != internal_malloc_scratch_limit) { + result = internal_malloc_scratch_limit; + } + else if (0 == internal_malloc_kind) { + result = LIBXSMM_MALLOC_SCRATCH_LIMIT; + } + else { + result = LIBXSMM_SCRATCH_UNLIMITED; + } + return result; +} + + +LIBXSMM_API void libxsmm_set_malloc(int enabled, const size_t* lo, const size_t* hi) +{ + /* !LIBXSMM_INIT */ +#if defined(LIBXSMM_MALLOC_HOOK) && defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC) +# if (0 < LIBXSMM_MALLOC) + LIBXSMM_UNUSED(enabled); + internal_malloc_kind = LIBXSMM_MALLOC; +# else + internal_malloc_kind = enabled; +# endif + /* setup lo/hi after internal_malloc_kind! */ + if (NULL != lo) internal_malloc_limit[0] = *lo; + if (NULL != hi) { + const size_t scratch_limit = libxsmm_get_scratch_limit(); + const size_t malloc_upper = LIBXSMM_MIN(*hi, scratch_limit); + internal_malloc_limit[1] = LIBXSMM_MAX(malloc_upper, internal_malloc_limit[0]); + } +#else + LIBXSMM_UNUSED(lo); LIBXSMM_UNUSED(hi); + internal_malloc_kind = enabled; +#endif + libxsmm_malloc_init(); +} + + +LIBXSMM_API int libxsmm_get_malloc(size_t* lo, size_t* hi) +{ + LIBXSMM_INIT +#if defined(LIBXSMM_MALLOC_HOOK) && defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC) + if (NULL != lo) *lo = internal_malloc_limit[0]; + if (NULL != hi) *hi = internal_malloc_limit[1]; +#else + if (NULL != lo) *lo = 0; + if (NULL != hi) *hi = 0; +#endif + return internal_malloc_kind; +} + diff --git a/third_party/libxsmm/src/libxsmm_math.c b/third_party/libxsmm/src/libxsmm_math.c new file mode 100644 index 0000000000000000000000000000000000000000..7c8f8e36813055d99aa1b5cac511239b2d9a32e6 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_math.c @@ -0,0 +1,569 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if !defined(LIBXSMM_NO_LIBM) +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#define LIBXSMM_MATDIFF_DIV(NOMINATOR, DENREF, DENTST) \ + (0 < (DENREF) ? ((NOMINATOR) / (DENREF)) : \ + (0 < (DENTST) ? ((NOMINATOR) / (DENTST)) : 0)) + + +LIBXSMM_API int libxsmm_matdiff(libxsmm_matdiff_info* info, + libxsmm_datatype datatype, libxsmm_blasint m, libxsmm_blasint n, const void* ref, const void* tst, + const libxsmm_blasint* ldref, const libxsmm_blasint* ldtst) +{ + int result = EXIT_SUCCESS, result_swap = 0, result_nan = 0; + libxsmm_blasint ldr = (NULL == ldref ? m : *ldref), ldt = (NULL == ldtst ? m : *ldtst); + if (NULL == ref && NULL != tst) { ref = tst; tst = NULL; result_swap = 1; } + if (NULL != ref && NULL != info && m <= ldr && m <= ldt) { + const size_t ntotal = (size_t)m * n; + libxsmm_blasint mm = m, nn = n; + double inf; + if (1 == n) { mm = ldr = ldt = 1; nn = m; } /* ensure row-vector shape to standardize results */ + libxsmm_matdiff_clear(info); + inf = info->min_ref; + switch (datatype) { + case LIBXSMM_DATATYPE_F64: { +# define LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE double +# include "template/libxsmm_matdiff.tpl.c" +# undef LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE + } break; + case LIBXSMM_DATATYPE_F32: { +# define LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE float +# include "template/libxsmm_matdiff.tpl.c" +# undef LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE + } break; + case LIBXSMM_DATATYPE_I32: { +# define LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE int +# include "template/libxsmm_matdiff.tpl.c" +# undef LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE + } break; + case LIBXSMM_DATATYPE_I16: { +# define LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE short +# include "template/libxsmm_matdiff.tpl.c" +# undef LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE + } break; + case LIBXSMM_DATATYPE_I8: { +# define LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE signed char +# include "template/libxsmm_matdiff.tpl.c" +# undef LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE + } break; + default: { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n"); + } + result = EXIT_FAILURE; + } + } + LIBXSMM_ASSERT((0 <= info->m && 0 <= info->n) || (0 > info->m && 0 > info->n)); + LIBXSMM_ASSERT(info->m < mm && info->n < nn); + if (EXIT_SUCCESS == result) { + const char *const env = getenv("LIBXSMM_DUMP"); + LIBXSMM_INIT + if (NULL != env && 0 != *env && '0' != *env) { + if ('-' != *env || (0 <= info->m && 0 <= info->n)) { + const char *const defaultname = (('0' < *env && '9' >= *env) || '-' == *env) ? "libxsmm_dump" : env; + const libxsmm_mhd_elemtype type_src = (libxsmm_mhd_elemtype)datatype; + const libxsmm_mhd_elemtype type_dst = LIBXSMM_MIN(LIBXSMM_MHD_ELEMTYPE_F32, type_src); + const int envi = atoi(env), reshape = (1 < envi || -1 > envi); + size_t shape[2], size[2]; + char filename[256]; + if (0 == reshape) { + shape[0] = (size_t)mm; shape[1] = (size_t)nn; + size[0] = (size_t)ldr; size[1] = (size_t)nn; + } + else { /* reshape */ + const size_t y = (size_t)libxsmm_isqrt2_u32((unsigned int)ntotal); + shape[0] = ntotal / y; shape[1] = y; + size[0] = shape[0]; + size[1] = shape[1]; + } + LIBXSMM_SNPRINTF(filename, sizeof(filename), "%s-%p-ref.mhd", defaultname, ref); + libxsmm_mhd_write(filename, NULL/*offset*/, shape, size, 2/*ndims*/, 1/*ncomponents*/, + type_src, &type_dst, ref, NULL/*header_size*/, NULL/*extension_header*/, + NULL/*extension*/, 0/*extension_size*/); + if (NULL != tst) { + if (0 == reshape) { + size[0] = (size_t)ldt; + size[1] = (size_t)nn; + } + LIBXSMM_SNPRINTF(filename, sizeof(filename), "%s-%p-tst.mhd", defaultname, ref/*adopt ref-ptr*/); + libxsmm_mhd_write(filename, NULL/*offset*/, shape, size, 2/*ndims*/, 1/*ncomponents*/, + type_src, &type_dst, tst, NULL/*header_size*/, NULL/*extension_header*/, + NULL/*extension*/, 0/*extension_size*/); + if ('-' == *env && '1' < env[1]) { + printf("LIBXSMM MATDIFF (%s): m=%" PRIuPTR " n=%" PRIuPTR " ldi=%" PRIuPTR " ldo=%" PRIuPTR " failed.\n", + libxsmm_typename(datatype), (uintptr_t)m, (uintptr_t)n, (uintptr_t)ldr, (uintptr_t)ldt); + } + } + } + else if ('-' == *env && '1' < env[1] && NULL != tst) { + printf("LIBXSMM MATDIFF (%s): m=%" PRIuPTR " n=%" PRIuPTR " ldi=%" PRIuPTR " ldo=%" PRIuPTR " passed.\n", + libxsmm_typename(datatype), (uintptr_t)m, (uintptr_t)n, (uintptr_t)ldr, (uintptr_t)ldt); + } + } + if (0 == result_nan) { + info->rsq = 1.0 - LIBXSMM_MATDIFF_DIV(info->l2_abs, info->var_ref, info->var_tst); + if (0 != ntotal) { /* final variance */ + info->var_ref /= ntotal; + info->var_tst /= ntotal; + } + info->normf_rel = libxsmm_dsqrt(info->normf_rel); + info->l2_abs = libxsmm_dsqrt(info->l2_abs); + info->l2_rel = libxsmm_dsqrt(info->l2_rel); + } + else if (1 == result_nan) { + /* in case of NaN in test-set, statistics is not set to inf (ref/test) */ + info->norm1_abs = info->norm1_rel = info->normi_abs = info->normi_rel = info->normf_rel + = info->linf_abs = info->linf_rel = info->l2_abs = info->l2_rel + = inf; + } + if (1 == n) LIBXSMM_ISWAP(info->m, info->n); + if (0 != result_swap) { + info->min_tst = info->min_ref; + info->min_ref = 0; + info->max_tst = info->max_ref; + info->max_ref = 0; + info->avg_tst = info->avg_ref; + info->avg_ref = 0; + info->var_tst = info->var_ref; + info->var_ref = 0; + info->l1_tst = info->l1_ref; + info->l1_ref = 0; + info->v_tst = info->v_ref; + info->v_ref = 0; + } + } + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API void libxsmm_matdiff_reduce(libxsmm_matdiff_info* output, const libxsmm_matdiff_info* input) +{ + if (NULL != output && NULL != input) { + if (output->linf_abs < input->linf_abs) { + output->linf_abs = input->linf_abs; + output->linf_rel = input->linf_rel; + output->v_ref = input->v_ref; + output->v_tst = input->v_tst; + LIBXSMM_ASSERT(0 <= input->m); + output->m = input->m; + LIBXSMM_ASSERT(0 <= input->n); + output->n = input->n; + } + if (output->norm1_abs < input->norm1_abs) { + output->norm1_abs = input->norm1_abs; + output->norm1_rel = input->norm1_rel; + } + if (output->normi_abs < input->normi_abs) { + output->normi_abs = input->normi_abs; + output->normi_rel = input->normi_rel; + } + if (output->l2_abs < input->l2_abs) { + output->l2_abs = input->l2_abs; + output->l2_rel = input->l2_rel; + output->rsq = input->rsq; + } + if (output->normf_rel < input->normf_rel) { + output->normf_rel = input->normf_rel; + } + if (output->var_ref < input->var_ref) { + output->var_ref = input->var_ref; + } + if (output->var_tst < input->var_tst) { + output->var_tst = input->var_tst; + } + if (output->max_ref < input->max_ref) { + output->max_ref = input->max_ref; + } + if (output->max_tst < input->max_tst) { + output->max_tst = input->max_tst; + } + if (output->min_ref > input->min_ref) { + output->min_ref = input->min_ref; + } + if (output->min_tst > input->min_tst) { + output->min_tst = input->min_tst; + } + output->avg_ref = 0.5 * (output->avg_ref + input->avg_ref); + output->avg_tst = 0.5 * (output->avg_tst + input->avg_tst); + output->l1_ref += input->l1_ref; + output->l1_tst += input->l1_tst; + } + else { + libxsmm_matdiff_clear(output); + } +} + + +LIBXSMM_API void libxsmm_matdiff_clear(libxsmm_matdiff_info* info) +{ + if (NULL != info) { + union { int raw; float value; } inf; +#if defined(INFINITY) && /*overflow warning*/!defined(_CRAYC) + inf.value = (float)(INFINITY); +#else + inf.raw = 0x7F800000; +#endif + memset(info, 0, sizeof(*info)); /* nullify */ + /* no location discovered yet with a difference */ + info->m = info->n = -1; + /* initial minimum/maximum of reference/test */ + info->min_ref = info->min_tst = +inf.value; + info->max_ref = info->max_tst = -inf.value; + } +} + + +LIBXSMM_API size_t libxsmm_shuffle(unsigned int n) +{ + const unsigned int s = (0 != (n & 1) ? ((n / 2 - 1) | 1) : ((n / 2) & ~1)); + const unsigned int d = (0 != (n & 1) ? 1 : 2); + unsigned int result = (1 < n ? 1 : 0), i; + for (i = (d < n ? (n - 1) : 0); d < i; i -= d) { + unsigned int c = (s <= i ? (i - s) : (s - i)); + unsigned int a = n, b = c; + do { + const unsigned int r = a % b; + a = b; + b = r; + } while (0 != b); + if (1 == a) { + result = c; + if (2 * c <= n) { + i = d; /* break */ + } + } + } + assert((0 == result && 1 >= n) || (result < n && 1 == libxsmm_gcd(result, n))); + return result; +} + + +LIBXSMM_API unsigned int libxsmm_isqrt_u64(unsigned long long x) +{ + unsigned long long b; unsigned int y = 0, s; + for (s = 0x80000000/*2^31*/; 0 < s; s >>= 1) { + b = y | s; y |= (b * b <= x ? s : 0); + } + return y; +} + + +LIBXSMM_API unsigned int libxsmm_isqrt_u32(unsigned int x) +{ + unsigned int b; unsigned int y = 0; int s; + for (s = 0x40000000/*2^30*/; 0 < s; s >>= 2) { + b = y | s; y >>= 1; + if (b <= x) { x -= b; y |= s; } + } + return y; +} + + +LIBXSMM_API unsigned int libxsmm_isqrt2_u32(unsigned int x) +{ + return libxsmm_product_limit(x, libxsmm_isqrt_u32(x), 0/*is_lower*/); +} + + +LIBXSMM_API double libxsmm_kahan_sum(double value, double* accumulator, double* compensation) +{ + double r, c; + LIBXSMM_ASSERT(NULL != accumulator && NULL != compensation); + c = value - *compensation; r = *accumulator + c; + *compensation = (r - *accumulator) - c; + *accumulator = r; + return r; +} + + +LIBXSMM_API LIBXSMM_INTRINSICS(LIBXSMM_X86_GENERIC) double libxsmm_dsqrt(double x) +{ +#if defined(LIBXSMM_INTRINSICS_X86) && !defined(__PGI) + const __m128d a = LIBXSMM_INTRINSICS_MM_UNDEFINED_PD(); + const double result = _mm_cvtsd_f64(_mm_sqrt_sd(a, _mm_set_sd(x))); +#elif !defined(LIBXSMM_NO_LIBM) + const double result = sqrt(x); +#else /* fallback */ + double result, y = x; + if (LIBXSMM_NEQ(0, x)) { + do { + result = y; + y = 0.5 * (y + x / y); + } while (LIBXSMM_NEQ(result, y)); + } + result = y; +#endif + return result; +} + + +LIBXSMM_API LIBXSMM_INTRINSICS(LIBXSMM_X86_GENERIC) float libxsmm_ssqrt(float x) +{ +#if defined(LIBXSMM_INTRINSICS_X86) + const float result = _mm_cvtss_f32(_mm_sqrt_ss(_mm_set_ss(x))); +#elif !defined(LIBXSMM_NO_LIBM) + const float result = LIBXSMM_SQRTF(x); +#else /* fallback */ + float result, y = x; + if (LIBXSMM_NEQ(0, x)) { + do { + result = y; + y = 0.5f * (y + x / y); + } while (LIBXSMM_NEQ(result, y)); + } + result = y; +#endif + return result; +} + + +LIBXSMM_API unsigned int libxsmm_icbrt_u64(unsigned long long x) +{ + unsigned long long b; unsigned int y = 0; int s; + for (s = 63; 0 <= s; s -= 3) { + y += y; b = ((unsigned long long)y + 1) * 3 * y + 1ULL; + if (b <= (x >> s)) { x -= b << s; ++y; } + } + return y; +} + + +LIBXSMM_API unsigned int libxsmm_icbrt_u32(unsigned int x) +{ + unsigned int b; unsigned int y = 0; int s; + for (s = 30; 0 <= s; s -= 3) { + y += y; b = 3 * y * (y + 1) + 1; + if (b <= (x >> s)) { x -= b << s; ++y; } + } + return y; +} + +#if defined(LIBXSMM_NO_LIBM) +/* Implementation based on Claude Baumann's product (http://www.convict.lu/Jeunes/ultimate_stuff/exp_ln_2.htm). + * Exponential function, which exposes the number of iterations taken in the main case (1...22). + */ +LIBXSMM_API_INLINE float internal_math_sexp2(float x, int maxiter) +{ + static const float lut[] = { /* tabulated powf(2.f, powf(2.f, -index)) */ + 2.00000000f, 1.41421354f, 1.18920708f, 1.09050775f, 1.04427373f, 1.02189720f, 1.01088929f, 1.00542986f, + 1.00271130f, 1.00135469f, 1.00067711f, 1.00033855f, 1.00016928f, 1.00008464f, 1.00004232f, 1.00002110f, + 1.00001061f, 1.00000525f, 1.00000262f, 1.00000131f, 1.00000072f, 1.00000036f, 1.00000012f + }; + const int lut_size = sizeof(lut) / sizeof(*lut), lut_size1 = lut_size - 1; + int sign, temp, unbiased, exponent, mantissa; + union { int i; float s; } result; + + result.s = x; + sign = (0 == (result.i & 0x80000000) ? 0 : 1); + temp = result.i & 0x7FFFFFFF; /* clear sign */ + unbiased = (temp >> 23) - 127; /* exponent */ + exponent = -unbiased; + mantissa = (temp << 8) | 0x80000000; + + if (lut_size1 >= exponent) { + if (lut_size1 != exponent) { /* multiple lookups needed */ + if (7 >= unbiased) { /* not a degenerated case */ + const int n = (0 >= maxiter || lut_size1 <= maxiter) ? lut_size1 : maxiter; + int i = 1; + if (0 > unbiased) { /* regular/main case */ + LIBXSMM_ASSERT(0 <= exponent && exponent < lut_size); + result.s = lut[exponent]; /* initial value */ + i = exponent + 1; /* next LUT offset */ + } + else { + result.s = 2.f; /* lut[0] */ + i = 1; /* next LUT offset */ + } + for (; i <= n && 0 != mantissa; ++i) { + mantissa <<= 1; + if (0 != (mantissa & 0x80000000)) { /* check MSB */ + LIBXSMM_ASSERT(0 <= i && i < lut_size); + result.s *= lut[i]; /* TODO: normalized multiply */ + } + } + for (i = 0; i < unbiased; ++i) { /* compute squares */ + result.s *= result.s; + } + if (0 != sign) { /* negative value, so reciprocal */ + result.s = 1.f / result.s; + } + } + else { /* out of range */ +#if defined(INFINITY) && /*overflow warning*/!defined(_CRAYC) + result.s = (0 == sign ? ((float)(INFINITY)) : 0.f); +#else + result.i = (0 == sign ? 0x7F800000 : 0); +#endif + } + } + else if (0 == sign) { + result.s = lut[lut_size1]; + } + else { /* reciprocal */ + result.s = 1.f / lut[lut_size1]; + } + } + else { + result.s = 1.f; /* case 2^0 */ + } + return result.s; +} +#endif + + +LIBXSMM_API float libxsmm_sexp2(float x) +{ +#if !defined(LIBXSMM_NO_LIBM) + return LIBXSMM_EXP2F(x); +#else /* fallback */ + return internal_math_sexp2(x, 20/*compromise*/); +#endif +} + + +LIBXSMM_API float libxsmm_sexp2_u8(unsigned char x) +{ + union { int i; float s; } result; + if (128 > x) { + if (31 < x) { + const float r32 = 2.f * ((float)(1U << 31)); /* 2^32 */ + const int n = x >> 5; + int i; + result.s = r32; + for (i = 1; i < n; ++i) result.s *= r32; + result.s *= (1U << (x - (n << 5))); + } + else { + result.s = (float)(1U << x); + } + } + else { +#if defined(INFINITY) && /*overflow warning*/!defined(_CRAYC) + result.s = (float)(INFINITY); +#else + result.i = 0x7F800000; +#endif + } + return result.s; +} + + +LIBXSMM_API float libxsmm_sexp2_i8(signed char x) +{ + union { int i; float s; } result; + if (-128 != x) { + const signed char ux = (signed char)LIBXSMM_ABS(x); + if (31 < ux) { + const float r32 = 2.f * ((float)(1U << 31)); /* 2^32 */ + const int n = ux >> 5; + int i; + result.s = r32; + for (i = 1; i < n; ++i) result.s *= r32; + result.s *= (1U << (ux - (n << 5))); + } + else { + result.s = (float)(1U << ux); + } + if (ux != x) { /* signed */ + result.s = 1.f / result.s; + } + } + else { + result.i = 0x200000; + } + return result.s; +} + + +LIBXSMM_API float libxsmm_sexp2_i8i(int x) +{ + LIBXSMM_ASSERT(-128 <= x && x <= 127); + return libxsmm_sexp2_i8((signed char)x); +} + + +#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matdiff)(libxsmm_matdiff_info* /*info*/, + const int* /*datatype*/, const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const void* /*ref*/, const void* /*tst*/, + const libxsmm_blasint* /*ldref*/, const libxsmm_blasint* /*ldtst*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matdiff)(libxsmm_matdiff_info* info, + const int* datatype, const libxsmm_blasint* m, const libxsmm_blasint* n, const void* ref, const void* tst, + const libxsmm_blasint* ldref, const libxsmm_blasint* ldtst) +{ + static int error_once = 0; + if ((NULL == datatype || LIBXSMM_DATATYPE_UNSUPPORTED <= *datatype || 0 > *datatype || NULL == m + || EXIT_SUCCESS != libxsmm_matdiff(info, (libxsmm_datatype)*datatype, *m, *(NULL != n ? n : m), ref, tst, ldref, ldtst)) + && 0 != libxsmm_verbosity && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_matdiff specified!\n"); + } +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matdiff_reduce)(libxsmm_matdiff_info* /*output*/, const libxsmm_matdiff_info* /*input*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matdiff_reduce)(libxsmm_matdiff_info* output, const libxsmm_matdiff_info* input) +{ + libxsmm_matdiff_reduce(output, input); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matdiff_clear)(libxsmm_matdiff_info* /*info*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matdiff_clear)(libxsmm_matdiff_info* info) +{ + libxsmm_matdiff_clear(info); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_shuffle)(long long* /*coprime*/, const int* /*n*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_shuffle)(long long* coprime, const int* n) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != coprime && NULL != n && 0 <= *n) +#endif + { + *coprime = (long long)(libxsmm_shuffle((unsigned int)(*n)) & 0x7FFFFFFF); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_shuffle specified!\n"); + } +#endif +} + +#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_matrixeqn.c b/third_party/libxsmm/src/libxsmm_matrixeqn.c new file mode 100644 index 0000000000000000000000000000000000000000..54d8f5effede1e7388c920155b47f9146cbbbb7b --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_matrixeqn.c @@ -0,0 +1,1265 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#include "libxsmm_matrixeqn.h" + +/* aux struct for matrix equations */ +LIBXSMM_APIVAR_DEFINE(libxsmm_matrix_eqn* libxsmm_matrix_eqns[256]); +LIBXSMM_APIVAR_DEFINE(libxsmm_blasint libxsmm_matrix_eqns_init); +LIBXSMM_APIVAR_DEFINE(libxsmm_blasint libxsmm_matrix_eqns_count); + +LIBXSMM_API_INTERN libxsmm_matrix_eqn* libxsmm_matrix_eqn_get_equation( libxsmm_blasint eqn_idx ) { + return libxsmm_matrix_eqns[eqn_idx]; +} + +LIBXSMM_API_INTERN +libxsmm_matrix_eqn_bcast_type get_bcast_type_unary(libxsmm_meltw_unary_flags flags) { + libxsmm_matrix_eqn_bcast_type result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE; + if ((flags & LIBXSMM_MELTW_FLAG_UNARY_BCAST_ROW) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW; + } else if ((flags & LIBXSMM_MELTW_FLAG_UNARY_BCAST_COL) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL; + } else if ((flags & LIBXSMM_MELTW_FLAG_UNARY_BCAST_SCALAR) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR; + } + return result; +} + +LIBXSMM_API_INTERN +libxsmm_matrix_eqn_bcast_type get_bcast_type_binary(libxsmm_meltw_binary_flags flags, unsigned int side) { + libxsmm_matrix_eqn_bcast_type result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE; + if (side == RIGHT) { + if ((flags & LIBXSMM_MELTW_FLAG_BINARY_BCAST_ROW_IN_1) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW; + } else if ((flags & LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_1) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL; + } else if ((flags & LIBXSMM_MELTW_FLAG_BINARY_BCAST_SCALAR_IN_1) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR; + } + } + if (side == LEFT) { + if ((flags & LIBXSMM_MELTW_FLAG_BINARY_BCAST_ROW_IN_0) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW; + } else if ((flags & LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_0) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL; + } else if ((flags & LIBXSMM_MELTW_FLAG_BINARY_BCAST_SCALAR_IN_0) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR; + } + } + return result; +} + +LIBXSMM_API_INTERN +libxsmm_matrix_eqn_bcast_type get_bcast_type_ternary(libxsmm_meltw_ternary_flags flags, unsigned int side) { + libxsmm_matrix_eqn_bcast_type result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE; + if (side == RIGHT2) { + if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_ROW_IN_2) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW; + } else if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_COL_IN_2) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL; + } else if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_SCALAR_IN_2) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR; + } + } + if (side == RIGHT) { + if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_ROW_IN_1) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW; + } else if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_COL_IN_1) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL; + } else if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_SCALAR_IN_1) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR; + } + } + if (side == LEFT) { + if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_ROW_IN_0) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW; + } else if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_COL_IN_0) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL; + } else if ((flags & LIBXSMM_MELTW_FLAG_TERNARY_BCAST_SCALAR_IN_0) > 0) { + result = LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR; + } + } + return result; +} + +LIBXSMM_API_INTERN libxsmm_blasint can_overwrite_unary_input(libxsmm_matrix_eqn_elem* cur_node); +LIBXSMM_API_INTERN libxsmm_blasint can_overwrite_unary_input(libxsmm_matrix_eqn_elem* cur_node) { + libxsmm_blasint result = 1; + if (cur_node->info.u_op.type == LIBXSMM_MELTW_TYPE_UNARY_IDENTITY) { + result = 0; + } + if ((cur_node->le->tmp.dtype == LIBXSMM_DATATYPE_BF16) && (cur_node->tmp.dtype == LIBXSMM_DATATYPE_F32)) { + result = 0; + } + if (is_unary_opcode_transform_kernel(cur_node->info.u_op.type) > 0) { + result = 0; + } + return result; +} + +LIBXSMM_API_INTERN libxsmm_blasint can_overwrite_binary_input(libxsmm_matrix_eqn_elem* cur_node); +LIBXSMM_API_INTERN libxsmm_blasint can_overwrite_binary_input(libxsmm_matrix_eqn_elem* cur_node) { + libxsmm_blasint result = 1; + if (cur_node->info.b_op.type == LIBXSMM_MELTW_TYPE_BINARY_MATMUL) { + result = 0; + } + if (((cur_node->le->tmp.dtype == LIBXSMM_DATATYPE_BF16) || (cur_node->ri->tmp.dtype == LIBXSMM_DATATYPE_BF16)) && (cur_node->tmp.dtype == LIBXSMM_DATATYPE_F32)) { + result = 0; + } + return result; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_trv_dbg_print( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint indent ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_trv_dbg_print( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint indent ) { + libxsmm_blasint i; + libxsmm_blasint tree_print_indent = 4; + + for ( i = 0; i < indent; ++i ) { + if ( i < indent - tree_print_indent ) { + printf(" "); + } else { + if ( i % tree_print_indent == 0 ) { + printf("|"); + } else { + printf("-"); + } + } + } + + /* check if we are at an argument leaf, then we move up */ + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + libxsmm_blasint argid = cur_node->info.arg.in_pos; + if ( (cur_node->le == NULL) && (cur_node->ri == NULL) ) { + if (argid >= 0) { + printf("ARG: M=%i, N=%i, LD=%i, arg_id=%i, dtype=%i\n", cur_node->info.arg.m, cur_node->info.arg.n, cur_node->info.arg.ld, cur_node->info.arg.in_pos, LIBXSMM_TYPESIZE(cur_node->info.arg.dtype) ); + } else { + printf("ARG: M=%i, N=%i, LD=%i, arg_id is scratch=%i, dtype=%i\n", cur_node->info.arg.m, cur_node->info.arg.n, cur_node->info.arg.ld, -1-argid, LIBXSMM_TYPESIZE(cur_node->info.arg.dtype) ); + } + } else { + printf("ERROR: Arg cannot have left or right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + /* we have to push more in this branch */ + if ( cur_node->le != NULL ) { + printf("UNARY: type=%i, flags=%i, timestamp=%i, out_tmp_id=%i, out_dtype=%i\n", (int)cur_node->info.u_op.type, (int)cur_node->info.u_op.flags, cur_node->visit_timestamp, cur_node->tmp.id, LIBXSMM_TYPESIZE(cur_node->tmp.dtype)); + libxsmm_matrix_eqn_trv_dbg_print( cur_node->le, indent+tree_print_indent ); + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( (cur_node->ri != NULL) ) { + printf("ERROR: Unary cannot have right childs!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + /* we have to push more in this branch */ + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) ) { + printf("BINARY: type=%i, flags=%i, timestamp=%i, out_tmp_id=%i, out_dtype=%i\n", (int)cur_node->info.b_op.type, (int)cur_node->info.b_op.flags, cur_node->visit_timestamp, cur_node->tmp.id, LIBXSMM_TYPESIZE(cur_node->tmp.dtype)); + libxsmm_matrix_eqn_trv_dbg_print( cur_node->le, indent+tree_print_indent ); + libxsmm_matrix_eqn_trv_dbg_print( cur_node->ri, indent+tree_print_indent ); + } else { + printf("ERROR: Binary needs left and right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + /* we have to push more in this branch */ + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) && (cur_node->r2 != NULL)) { + printf("TERNARY: type=%i, flags=%i, timestamp=%i, out_tmp_id=%i, out_dtype=%i\n", (int)cur_node->info.t_op.type, (int)cur_node->info.t_op.flags, cur_node->visit_timestamp, cur_node->tmp.id, LIBXSMM_TYPESIZE(cur_node->tmp.dtype)); + libxsmm_matrix_eqn_trv_dbg_print( cur_node->le, indent+tree_print_indent ); + libxsmm_matrix_eqn_trv_dbg_print( cur_node->ri, indent+tree_print_indent ); + libxsmm_matrix_eqn_trv_dbg_print( cur_node->r2, indent+tree_print_indent ); + } else { + printf("ERROR: Ternary needs three children!\n"); + } + } else { + /* shouldn't happen */ + } +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_assign_reg_scores( libxsmm_matrix_eqn_elem* cur_node ) { + /* check if we are at an argument leaf, then we assign register score 0 */ + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + if ( (cur_node->le == NULL) && (cur_node->ri == NULL) ) { + cur_node->reg_score = 0; + } + else { + printf("ERROR: Arg cannot have left or right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + /* If the node is unary type we have the following cases: + * 1) If the left child is an arg, we just set the score to 1 (we do not overwrite the input) + * 2) if the left child is NOT an arg AND we can overwrite the tmp, we just propagate the register score from it (no additional tmp storage is needed) + * 3) if the left child is NOT an arg AND we CAN NOT overwrite the tmp, we should make the register score at least 2 + * */ + if ( cur_node->le != NULL ) { + libxsmm_matrix_eqn_assign_reg_scores( cur_node->le ); + if ( cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + cur_node->reg_score = 1; + } else { + if (can_overwrite_unary_input(cur_node) > 0) { + cur_node->reg_score = cur_node->le->reg_score; + } else { + cur_node->reg_score = LIBXSMM_MAX(2, cur_node->le->reg_score); + } + } + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( (cur_node->ri != NULL) ) { + printf("ERROR: Unary cannot have right childs!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) ) { + libxsmm_matrix_eqn_assign_reg_scores( cur_node->le ); + libxsmm_matrix_eqn_assign_reg_scores( cur_node->ri ); + + /* If left and right are args, we just need 1 tmp */ + if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->reg_score = 1; + } else { + if (can_overwrite_binary_input(cur_node) > 0) { + /* If the node is binary type we have two cases: + * 1) If the left/right subtrees have the same register score, we have to increase it by one (i.e. we have to first compute one of the subtrees and keep the result in a tmp storage and then compute the other subtree, so we would need an extra tmp storage) + * 2) If the left/right subtrees DO NOT have the same register score, then we assign the maximum of the register scores (i.e. we would compute first the subtree with the maximum score and then the tree with the smallest score, thus no extra tmp storage is required) */ + if (cur_node->le->reg_score == cur_node->ri->reg_score) { + cur_node->reg_score = cur_node->le->reg_score + 1; + } else { + cur_node->reg_score = LIBXSMM_MAX(cur_node->le->reg_score, cur_node->ri->reg_score); + } + } else { + if (cur_node->le->reg_score == cur_node->ri->reg_score) { + cur_node->reg_score = LIBXSMM_MAX(3, cur_node->le->reg_score + 1); + } else { + cur_node->reg_score = LIBXSMM_MAX(3, LIBXSMM_MAX(cur_node->le->reg_score, cur_node->ri->reg_score)); + } + } + } + } else { + printf("ERROR: Binary needs left and right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) && (cur_node->r2 != NULL) ) { + int use_r2_as_output = ((cur_node->info.t_op.flags & LIBXSMM_MELTW_FLAG_TERNARY_REUSE_IN_2_AS_OUT) > 0) ? 1 : 0; + libxsmm_matrix_eqn_assign_reg_scores( cur_node->le ); + libxsmm_matrix_eqn_assign_reg_scores( cur_node->ri ); + libxsmm_matrix_eqn_assign_reg_scores( cur_node->r2 ); + /* If all children re args, we just need 1 tmp */ + if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->reg_score = 1; + } else { + if (use_r2_as_output > 0) { + cur_node->reg_score = LIBXSMM_MAX(3, LIBXSMM_MAX(LIBXSMM_MAX(cur_node->le->reg_score, cur_node->ri->reg_score), cur_node->r2->reg_score)); + } else { + cur_node->reg_score = LIBXSMM_MAX(4, LIBXSMM_MAX(LIBXSMM_MAX(cur_node->le->reg_score, cur_node->ri->reg_score), cur_node->r2->reg_score)); + } + } + } else { + printf("ERROR: Ternary needs all three children!\n"); + } + } else { + /* shouldn't happen */ + } +} + +LIBXSMM_API_INTERN +void libxsmm_generator_assign_new_timestamp(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *current_timestamp ) { + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + /* Do not increase the timestamp, this node is just an arg so it's not part of the execution */ + cur_node->visit_timestamp = -1; + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + cur_node->visit_timestamp = *current_timestamp; + *current_timestamp = *current_timestamp + 1; + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + if (cur_node->le->reg_score >= cur_node->ri->reg_score) { + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + } else { + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + } + cur_node->visit_timestamp = *current_timestamp; + *current_timestamp = *current_timestamp + 1; + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + if ((cur_node->le->reg_score >= cur_node->ri->reg_score) && (cur_node->le->reg_score >= cur_node->r2->reg_score) ) { + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + if ( cur_node->ri->reg_score >= cur_node->r2->reg_score ) { + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->r2, current_timestamp ); + } else { + libxsmm_generator_assign_new_timestamp( cur_node->r2, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + } + } else if ((cur_node->ri->reg_score >= cur_node->le->reg_score) && (cur_node->ri->reg_score >= cur_node->r2->reg_score) ) { + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + if ( cur_node->le->reg_score >= cur_node->r2->reg_score ) { + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->r2, current_timestamp ); + } else { + libxsmm_generator_assign_new_timestamp( cur_node->r2, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + } + } else { + libxsmm_generator_assign_new_timestamp( cur_node->r2, current_timestamp ); + if ( cur_node->le->reg_score >= cur_node->ri->reg_score ) { + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + } else { + libxsmm_generator_assign_new_timestamp( cur_node->ri, current_timestamp ); + libxsmm_generator_assign_new_timestamp( cur_node->le, current_timestamp ); + } + } + } else { + /* shouldn't happen */ + } +} + +LIBXSMM_API_INTERN +void libxsmm_generator_matequation_assign_timestamps(libxsmm_matrix_eqn *eqn) { + libxsmm_blasint timestamp = 0; + libxsmm_generator_assign_new_timestamp(eqn->eqn_root, ×tamp ); +} + +LIBXSMM_API_INTERN libxsmm_blasint reserve_tmp_storage(libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool) { + libxsmm_blasint i; + if ( tmp_storage_pool != NULL ) { + for (i = 0; i < n_max_tmp; i++) { + if (tmp_storage_pool[i] == 0) { + tmp_storage_pool[i] = 1; + return i; + } + } + } + return -1; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_configure_unary_tmp(libxsmm_matrix_eqn_elem* cur_node); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_configure_unary_tmp(libxsmm_matrix_eqn_elem* cur_node) { + cur_node->tmp.m = cur_node->le->tmp.m; + cur_node->tmp.n = cur_node->le->tmp.n; + cur_node->tmp.ld = cur_node->le->tmp.ld; + cur_node->tmp.dtype = cur_node->info.u_op.dtype; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_configure_binary_tmp(libxsmm_matrix_eqn_elem* cur_node); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_configure_binary_tmp(libxsmm_matrix_eqn_elem* cur_node) { + cur_node->tmp.m = cur_node->le->tmp.m; + cur_node->tmp.ld = cur_node->le->tmp.ld; + if (cur_node->info.b_op.type == LIBXSMM_MELTW_TYPE_BINARY_MATMUL) { + cur_node->tmp.n = cur_node->ri->tmp.n; + } else { + cur_node->tmp.n = cur_node->le->tmp.n; + } + cur_node->tmp.dtype = cur_node->info.b_op.dtype; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_configure_ternary_tmp(libxsmm_matrix_eqn_elem* cur_node); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_configure_ternary_tmp(libxsmm_matrix_eqn_elem* cur_node) { + cur_node->tmp.m = cur_node->r2->tmp.m; + cur_node->tmp.n = cur_node->r2->tmp.n; + cur_node->tmp.ld = cur_node->r2->tmp.ld; + if (cur_node->info.t_op.type == LIBXSMM_MELTW_TYPE_TERNARY_MATMUL) { + cur_node->tmp.m = cur_node->r2->tmp.m; + cur_node->tmp.n = cur_node->r2->tmp.n; + cur_node->tmp.ld = cur_node->r2->tmp.ld; + } + cur_node->tmp.dtype = cur_node->info.t_op.dtype; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_arg_node(libxsmm_matrix_eqn_elem* cur_node); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_arg_node(libxsmm_matrix_eqn_elem* cur_node) { + /* Do not increase the timestamp, this node is just an arg so it's not part of the execution */ + cur_node->visit_timestamp = -1; + cur_node->n_args = 1; + cur_node->max_tmp_size = cur_node->info.arg.ld * cur_node->info.arg.n; + cur_node->tmp.m = cur_node->info.arg.m; + cur_node->tmp.n = cur_node->info.arg.n; + cur_node->tmp.ld = cur_node->info.arg.ld; + cur_node->tmp.dtype = cur_node->info.arg.dtype; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_unary_node(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_unary_node(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool) { + /* Assign timestamp and propagate info for n_args/max_tmp_size */ + cur_node->visit_timestamp = *global_timestamp; + *global_timestamp = *global_timestamp + 1; + cur_node->n_args = cur_node->le->n_args; + cur_node->max_tmp_size = cur_node->le->max_tmp_size; + /* When assigning the tmp output storage, we have two cases in the unary: + * 1) The child is an arg, so we have to reserve a tmp storage + * 2) The child is NOT an arg, so we just reuse the tmp storage of the child IF we are allowed to overwrite */ + if ( cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + cur_node->tree_max_comp_tsize = LIBXSMM_TYPESIZE( cur_node->info.u_op.dtype ); + } else { + if (can_overwrite_unary_input(cur_node) > 0) { + cur_node->tmp.id = cur_node->le->tmp.id; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE(cur_node->info.u_op.dtype), cur_node->le->tree_max_comp_tsize ); + } + libxsmm_matrix_eqn_exec_plan_configure_unary_tmp( cur_node ); +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_binary_node(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_binary_node(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool) { + /* Assign timestamp and propagate info for n_args/max_tmp_size */ + cur_node->visit_timestamp = *global_timestamp; + *global_timestamp = *global_timestamp + 1; + cur_node->n_args = cur_node->le->n_args + cur_node->ri->n_args; + cur_node->max_tmp_size = LIBXSMM_MAX(cur_node->le->max_tmp_size, cur_node->ri->max_tmp_size); + /* Max tmp size has to be adjusted if it is a MATMUL op */ + if (cur_node->info.b_op.type == LIBXSMM_MELTW_TYPE_BINARY_MATMUL) { + libxsmm_blasint matmul_out_size = cur_node->le->tmp.ld * cur_node->ri->tmp.n; + cur_node->max_tmp_size = LIBXSMM_MAX(matmul_out_size, cur_node->max_tmp_size); + } + /* When assigning the tmp output storage, we have three cases in the binary: + * 1) Both children are arg, so we have to reserve a tmp storage + * 2) Both child are NOT arg, so we reuse the tmp storage of either one for our output and we make the other tmp storage available IF we are allowed to overwrite + * 3) One child IS arg and the other child is NOT an arg, so we just reuse the tmp storage of the non-arg child IF we are allowed to overwrite */ + if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + cur_node->tree_max_comp_tsize = LIBXSMM_TYPESIZE( cur_node->info.b_op.dtype ); + } else if ( (cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + if (can_overwrite_binary_input(cur_node) > 0) { + cur_node->tmp.id = cur_node->le->tmp.id; + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.b_op.dtype ), LIBXSMM_MAX( cur_node->ri->tree_max_comp_tsize, cur_node->le->tree_max_comp_tsize )); + } else { + if (cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) { + if (can_overwrite_binary_input(cur_node) > 0) { + cur_node->tmp.id = cur_node->le->tmp.id; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE(cur_node->info.b_op.dtype), cur_node->le->tree_max_comp_tsize ); + } else { + if (can_overwrite_binary_input(cur_node) > 0) { + cur_node->tmp.id = cur_node->ri->tmp.id; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE(cur_node->info.b_op.dtype), cur_node->ri->tree_max_comp_tsize ); + } + } + libxsmm_matrix_eqn_exec_plan_configure_binary_tmp( cur_node ); +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_ternary_node(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_exec_plan_visit_ternary_node(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool) { + /* Assign timestamp and propagate info for n_args/max_tmp_size */ + int use_r2_as_output = ((cur_node->info.t_op.flags & LIBXSMM_MELTW_FLAG_TERNARY_REUSE_IN_2_AS_OUT) > 0) ? 1 : 0; + cur_node->visit_timestamp = *global_timestamp; + *global_timestamp = *global_timestamp + 1; + cur_node->n_args = cur_node->le->n_args + cur_node->ri->n_args + cur_node->r2->n_args; + cur_node->max_tmp_size = LIBXSMM_MAX( LIBXSMM_MAX(cur_node->le->max_tmp_size, cur_node->ri->max_tmp_size), cur_node->r2->max_tmp_size); + if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + cur_node->tree_max_comp_tsize = LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ); + } else if ( (cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type != LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + if (use_r2_as_output > 0 ) { + cur_node->tmp.id = cur_node->r2->tmp.id; + tmp_storage_pool[cur_node->le->tmp.id] = 0; + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + tmp_storage_pool[cur_node->r2->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( cur_node->r2->tree_max_comp_tsize, LIBXSMM_MAX( cur_node->ri->tree_max_comp_tsize, cur_node->le->tree_max_comp_tsize ))); + } else if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type != LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + if (use_r2_as_output > 0 ) { + cur_node->tmp.id = cur_node->r2->tmp.id; + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + tmp_storage_pool[cur_node->r2->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( cur_node->r2->tree_max_comp_tsize, LIBXSMM_MAX( cur_node->ri->tree_max_comp_tsize, 1 ))); + } else if ( (cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type != LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + if (use_r2_as_output > 0 ) { + cur_node->tmp.id = cur_node->r2->tmp.id; + tmp_storage_pool[cur_node->le->tmp.id] = 0; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + tmp_storage_pool[cur_node->r2->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( cur_node->r2->tree_max_comp_tsize, LIBXSMM_MAX( 1, cur_node->le->tree_max_comp_tsize ))); + } else if ( (cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( 1, LIBXSMM_MAX( cur_node->ri->tree_max_comp_tsize, cur_node->le->tree_max_comp_tsize ))); + } else if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type != LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + if (use_r2_as_output > 0 ) { + cur_node->tmp.id = cur_node->r2->tmp.id; + } else { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->r2->tmp.id] = 0; + } + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( cur_node->r2->tree_max_comp_tsize, LIBXSMM_MAX( 1, 1 ))); + } else if ( (cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->le->tmp.id] = 0; + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( 1, LIBXSMM_MAX( 1, cur_node->le->tree_max_comp_tsize ))); + } else if ( (cur_node->le->type == LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (cur_node->r2->type == LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->tmp.id = reserve_tmp_storage( n_max_tmp, tmp_storage_pool ); + tmp_storage_pool[cur_node->ri->tmp.id] = 0; + cur_node->tree_max_comp_tsize = LIBXSMM_MAX( LIBXSMM_TYPESIZE( cur_node->info.t_op.dtype ), LIBXSMM_MAX( 1, LIBXSMM_MAX( cur_node->ri->tree_max_comp_tsize, 1))); + } + libxsmm_matrix_eqn_exec_plan_configure_ternary_tmp( cur_node ); +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_reassign_children_bcast_tmp(libxsmm_matrix_eqn *eqn, libxsmm_matrix_eqn_elem* cur_node) { + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + /* Do nothing */ + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + if ((cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (get_bcast_type_unary(cur_node->info.u_op.flags) != LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE)) { + cur_node->le->tmp.id = eqn->eqn_root->reg_score; + eqn->eqn_root->reg_score = eqn->eqn_root->reg_score + 1; + } + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, cur_node->le); + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + if ((cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (get_bcast_type_binary(cur_node->info.b_op.flags, LEFT) != LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE)) { + cur_node->le->tmp.id = eqn->eqn_root->reg_score; + eqn->eqn_root->reg_score = eqn->eqn_root->reg_score + 1; + } + if ((cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (get_bcast_type_binary(cur_node->info.b_op.flags, RIGHT) != LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE)) { + cur_node->ri->tmp.id = eqn->eqn_root->reg_score; + eqn->eqn_root->reg_score = eqn->eqn_root->reg_score + 1; + } + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, cur_node->le); + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, cur_node->ri); + } else if( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + if ((cur_node->le->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (get_bcast_type_ternary(cur_node->info.t_op.flags, LEFT) != LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE)) { + cur_node->le->tmp.id = eqn->eqn_root->reg_score; + eqn->eqn_root->reg_score = eqn->eqn_root->reg_score + 1; + } + if ((cur_node->ri->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (get_bcast_type_ternary(cur_node->info.t_op.flags, RIGHT) != LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE)) { + cur_node->ri->tmp.id = eqn->eqn_root->reg_score; + eqn->eqn_root->reg_score = eqn->eqn_root->reg_score + 1; + } + if ((cur_node->r2->type != LIBXSMM_MATRIX_EQN_NODE_ARG) && (get_bcast_type_ternary(cur_node->info.t_op.flags, RIGHT2) != LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE)) { + cur_node->r2->tmp.id = eqn->eqn_root->reg_score; + eqn->eqn_root->reg_score = eqn->eqn_root->reg_score + 1; + } + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, cur_node->le); + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, cur_node->ri); + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, cur_node->r2); + } else { + /* This should not happen */ + } +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_reassign_bcast_tmp(libxsmm_matrix_eqn *eqn) { + libxsmm_matrix_eqn_elem* root = eqn->eqn_root; + if ( root->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, root->le); + } + if ( root->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, root->le); + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, root->ri); + } + if ( root->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, root->le); + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, root->ri); + libxsmm_matrix_eqn_reassign_children_bcast_tmp(eqn, root->r2); + } +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_create_exec_plan( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool ) { + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + libxsmm_matrix_eqn_exec_plan_visit_arg_node(cur_node); + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + /* First visit left child tree */ + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_exec_plan_visit_unary_node(cur_node, global_timestamp, n_max_tmp, tmp_storage_pool); + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + /* First we visit the child tree with the maximum register score */ + if (cur_node->le->reg_score >= cur_node->ri->reg_score) { + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + } else { + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + } + libxsmm_matrix_eqn_exec_plan_visit_binary_node(cur_node, global_timestamp, n_max_tmp, tmp_storage_pool); + } else if( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + if ((cur_node->le->reg_score >= cur_node->ri->reg_score) && (cur_node->le->reg_score >= cur_node->r2->reg_score) ) { + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + if ( cur_node->ri->reg_score >= cur_node->r2->reg_score ) { + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->r2, global_timestamp, n_max_tmp, tmp_storage_pool ); + } else { + libxsmm_matrix_eqn_create_exec_plan( cur_node->r2, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + } + } else if ((cur_node->ri->reg_score >= cur_node->le->reg_score) && (cur_node->ri->reg_score >= cur_node->r2->reg_score) ) { + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + if ( cur_node->le->reg_score >= cur_node->r2->reg_score ) { + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->r2, global_timestamp, n_max_tmp, tmp_storage_pool ); + } else { + libxsmm_matrix_eqn_create_exec_plan( cur_node->r2, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + } + } else { + libxsmm_matrix_eqn_create_exec_plan( cur_node->r2, global_timestamp, n_max_tmp, tmp_storage_pool ); + if ( cur_node->le->reg_score >= cur_node->ri->reg_score ) { + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + } else { + libxsmm_matrix_eqn_create_exec_plan( cur_node->ri, global_timestamp, n_max_tmp, tmp_storage_pool ); + libxsmm_matrix_eqn_create_exec_plan( cur_node->le, global_timestamp, n_max_tmp, tmp_storage_pool ); + } + } + libxsmm_matrix_eqn_exec_plan_visit_ternary_node(cur_node, global_timestamp, n_max_tmp, tmp_storage_pool); + } else { + /* This should not happen */ + } +} + +LIBXSMM_API_INTERN +int is_unary_opcode_reduce_kernel (unsigned int opcode) { + int result = 0; + if ((opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_ADD) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_MAX) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_MUL) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X2_OP_ADD) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_OP_ADD_NCNC_FORMAT) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_X_X2_OP_ADD)) { + result = 1; + } + return result; +} + +LIBXSMM_API_INTERN +int is_unary_opcode_transform_kernel (unsigned int opcode) { + int result = 0; + if ((opcode == LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI_TO_VNNIT) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNIT) || + (opcode == LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI_PAD)) { + result = 1; + } + return result; +} + +LIBXSMM_API_INTERN +int is_unary_opcode_reduce_to_scalar (unsigned int opcode) { + int result = 0; + if (opcode == LIBXSMM_MELTW_TYPE_UNARY_REDUCE_TO_SCALAR_OP_ADD) { + result = 1; + } + return result; +} + +LIBXSMM_API_INTERN +int is_binary_opcode_reduce_to_scalar (unsigned int opcode) { + int result = 0; + if (opcode == LIBXSMM_MELTW_TYPE_BINARY_MUL_AND_REDUCE_TO_SCALAR_OP_ADD) { + result = 1; + } + return result; +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_adjust_tmp_sizes( libxsmm_matrix_eqn_elem* cur_node ) { + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + /* Do nothing */ + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + libxsmm_matrix_eqn_adjust_tmp_sizes( cur_node->le ); + /* If it is reduce kernel, have to resize out tmp */ + if ( is_unary_opcode_reduce_kernel(cur_node->info.u_op.type) > 0 ) { + if ((cur_node->info.u_op.flags & LIBXSMM_MELTW_FLAG_UNARY_REDUCE_ROWS) > 0) { + cur_node->tmp.m = cur_node->le->tmp.n; + cur_node->tmp.n = 1; + cur_node->tmp.ld = cur_node->le->tmp.n; + } else if ((cur_node->info.u_op.flags & LIBXSMM_MELTW_FLAG_UNARY_REDUCE_COLS) > 0) { + cur_node->tmp.m = cur_node->le->tmp.m; + cur_node->tmp.n = 1; + cur_node->tmp.ld = cur_node->le->tmp.ld; + } + } else if ( is_unary_opcode_reduce_to_scalar(cur_node->info.u_op.type) > 0 ) { + cur_node->tmp.m = 1; + cur_node->tmp.n = 1; + cur_node->tmp.ld = 1; + } else if ( is_unary_opcode_transform_kernel(cur_node->info.u_op.type) > 0 ) { + cur_node->tmp.m = cur_node->le->tmp.n; + cur_node->tmp.n = cur_node->le->tmp.m; + cur_node->tmp.ld = cur_node->le->tmp.n; + } else { + cur_node->tmp.m = cur_node->le->tmp.m; + cur_node->tmp.n = cur_node->le->tmp.n; + cur_node->tmp.ld = cur_node->le->tmp.ld; + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + libxsmm_matrix_eqn_adjust_tmp_sizes( cur_node->le); + libxsmm_matrix_eqn_adjust_tmp_sizes( cur_node->ri); + if ( is_binary_opcode_reduce_to_scalar(cur_node->info.b_op.type) > 0 ) { + cur_node->tmp.m = 1; + cur_node->tmp.n = 1; + cur_node->tmp.ld = 1; + } else { + cur_node->tmp.m = LIBXSMM_MAX(cur_node->le->tmp.m, cur_node->ri->tmp.m); + cur_node->tmp.n = LIBXSMM_MAX(cur_node->le->tmp.n, cur_node->ri->tmp.n); + cur_node->tmp.ld = LIBXSMM_MAX(cur_node->le->tmp.ld, cur_node->ri->tmp.ld); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + libxsmm_matrix_eqn_adjust_tmp_sizes( cur_node->le ); + libxsmm_matrix_eqn_adjust_tmp_sizes( cur_node->ri); + libxsmm_matrix_eqn_adjust_tmp_sizes( cur_node->r2); + cur_node->tmp.m = LIBXSMM_MAX(cur_node->r2->tmp.m, LIBXSMM_MAX(cur_node->le->tmp.m, cur_node->ri->tmp.m)); + cur_node->tmp.n = LIBXSMM_MAX(cur_node->r2->tmp.n, LIBXSMM_MAX(cur_node->le->tmp.n, cur_node->ri->tmp.n)); + cur_node->tmp.ld = LIBXSMM_MAX( cur_node->r2->tmp.ld, LIBXSMM_MAX(cur_node->le->tmp.ld, cur_node->ri->tmp.ld)); + } +} + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_opt_exec_plan( libxsmm_blasint idx ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_opt_exec_plan( libxsmm_blasint idx ) { + libxsmm_blasint global_timestamp = 0; + libxsmm_blasint max_reg_score = 0; + libxsmm_blasint *tmp_storage_pool = NULL; + libxsmm_blasint i; + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist, nothing to optimize!\n" ); + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 0 ) { + fprintf( stderr, "the requested equation is not yet finalized, so can't optimize!\n" ); + } +#if 0 + printf("\n"); + printf("Assigning register scores to find optimal traversal plan (i.e. that minimizes tmp storage)... \n"); +#endif + libxsmm_matrix_eqn_assign_reg_scores( libxsmm_matrix_eqns[idx]->eqn_root ); + max_reg_score = libxsmm_matrix_eqns[idx]->eqn_root->reg_score; + tmp_storage_pool = (libxsmm_blasint*) malloc(max_reg_score * sizeof(libxsmm_blasint)); + if (tmp_storage_pool == NULL) { + fprintf( stderr, "Tmp storage allocation array failed...\n" ); + return; + } else { + for (i = 0; i < max_reg_score; i++) { + tmp_storage_pool[i] = 0; + } + } +#if 0 + printf("Optimal number of intermediate tmp storage is %d\n", max_reg_score); +#endif + libxsmm_matrix_eqn_create_exec_plan( libxsmm_matrix_eqns[idx]->eqn_root, &global_timestamp, max_reg_score, tmp_storage_pool ); + libxsmm_matrix_eqn_adjust_tmp_sizes( libxsmm_matrix_eqns[idx]->eqn_root ); + libxsmm_matrix_eqn_reassign_bcast_tmp( libxsmm_matrix_eqns[idx] ); +#if 0 + printf("Created optimal exexution plan...\n"); +#endif + if (tmp_storage_pool != NULL) { + free(tmp_storage_pool); + } +#if 0 + printf("\n\n"); +#endif + libxsmm_matrix_eqns[idx]->is_optimized = 1; +} + +LIBXSMM_API_INTERN +void libxsmm_generator_reoptimize_eqn(libxsmm_matrix_eqn *eqn) { + libxsmm_blasint max_reg_score = 0, global_timestamp = 0, i = 0; + libxsmm_blasint *tmp_storage_pool = NULL; + libxsmm_matrix_eqn_assign_reg_scores( eqn->eqn_root ); + max_reg_score = eqn->eqn_root->reg_score; + tmp_storage_pool = (libxsmm_blasint*) malloc(max_reg_score * sizeof(libxsmm_blasint)); + if (tmp_storage_pool == NULL) { + fprintf( stderr, "Tmp storage allocation array failed...\n" ); + return; + } else { + for (i = 0; i < max_reg_score; i++) { + tmp_storage_pool[i] = 0; + } + } + libxsmm_matrix_eqn_create_exec_plan( eqn->eqn_root, &global_timestamp, max_reg_score, tmp_storage_pool ); + libxsmm_matrix_eqn_adjust_tmp_sizes( eqn->eqn_root ); + if (tmp_storage_pool != NULL) { + free(tmp_storage_pool); + } +} + +LIBXSMM_API_INTERN libxsmm_matrix_eqn_elem* libxsmm_matrix_eqn_add_node( libxsmm_matrix_eqn_elem* cur_node, libxsmm_matrix_eqn_node_type type, libxsmm_matrix_eqn_info info ); +LIBXSMM_API_INTERN libxsmm_matrix_eqn_elem* libxsmm_matrix_eqn_add_node( libxsmm_matrix_eqn_elem* cur_node, libxsmm_matrix_eqn_node_type type, libxsmm_matrix_eqn_info info ) { + if ( type == LIBXSMM_MATRIX_EQN_NODE_NONE ) { + /* shouldn't happen */ + fprintf( stderr, "wrong op node type to add!\n"); + } + + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + libxsmm_matrix_eqn_elem *node = (libxsmm_matrix_eqn_elem*) malloc( sizeof(libxsmm_matrix_eqn_elem) ); + + node->le = NULL; + node->ri = NULL; + node->r2 = NULL; + node->up = cur_node; + node->type = type; + node->info = info; + + if ( cur_node->le == NULL ) { + cur_node->le = node; + } else { + /* shouldn't happen */ + fprintf( stderr, "this is not a leaf node, so we cannot add a node!\n"); + free( node ); + node = NULL; + } + + return node; + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + libxsmm_matrix_eqn_elem *node = (libxsmm_matrix_eqn_elem*) malloc( sizeof(libxsmm_matrix_eqn_elem) ); + + node->le = NULL; + node->ri = NULL; + node->r2 = NULL; + node->up = cur_node; + node->type = type; + node->info = info; + + if ( cur_node->le == NULL ) { + cur_node->le = node; + } else if ( cur_node->ri == NULL ) { + cur_node->ri = node; + } else { + /* shouldn't happen */ + fprintf( stderr, "this is not a leaf node, so we cannot add a node!\n"); + free( node ); + node = NULL; + } + + return node; + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + libxsmm_matrix_eqn_elem *node = (libxsmm_matrix_eqn_elem*) malloc( sizeof(libxsmm_matrix_eqn_elem) ); + + node->le = NULL; + node->ri = NULL; + node->r2 = NULL; + node->up = cur_node; + node->type = type; + node->info = info; + + if ( cur_node->le == NULL ) { + cur_node->le = node; + } else if ( cur_node->ri == NULL ) { + cur_node->ri = node; + } else if ( cur_node->r2 == NULL ) { + cur_node->r2 = node; + } else { + /* shouldn't happen */ + fprintf( stderr, "this is not a leaf node, so we cannot add a node!\n"); + free( node ); + node = NULL; + } + + return node; + /* we converting the root */ + } else if ( (cur_node->type == LIBXSMM_MATRIX_EQN_NODE_NONE) && (type != LIBXSMM_MATRIX_EQN_NODE_ARG) ) { + cur_node->le = NULL; + cur_node->ri = NULL; + cur_node->r2 = NULL; + cur_node->up = NULL; + cur_node->type = type; + cur_node->info = info; + + return cur_node; + } else { + /* shouldn't happen */ + fprintf( stderr, "at this position we cannot add an op!\n"); + } + + return NULL; +} + + +LIBXSMM_API_INTERN libxsmm_matrix_eqn_elem* libxsmm_matrix_eqn_trv_head( libxsmm_matrix_eqn_elem* cur_node ); +LIBXSMM_API_INTERN libxsmm_matrix_eqn_elem* libxsmm_matrix_eqn_trv_head( libxsmm_matrix_eqn_elem* cur_node ) { + /* check if we are at an argument leaf, then we move up */ + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + return libxsmm_matrix_eqn_trv_head( cur_node->up ); + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + /* we have to push more in this branch */ + if ( cur_node->le == NULL ) { + return cur_node; + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( cur_node->up == NULL ) { + return cur_node; + /* we have to find another node */ + } else { + return libxsmm_matrix_eqn_trv_head( cur_node->up ); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + /* we have to push more in this branch */ + if ( cur_node->le == NULL ) { + return cur_node; + } else if ( cur_node->ri == NULL ) { + return cur_node; + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( cur_node->up == NULL ) { + return cur_node; + /* we have to find another node */ + } else { + return libxsmm_matrix_eqn_trv_head( cur_node->up ); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + /* we have to push more in this branch */ + if ( cur_node->le == NULL ) { + return cur_node; + } else if ( cur_node->ri == NULL ) { + return cur_node; + } else if ( cur_node->r2 == NULL ) { + return cur_node; + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( cur_node->up == NULL ) { + return cur_node; + /* we have to find another node */ + } else { + return libxsmm_matrix_eqn_trv_head( cur_node->up ); + } + } else { + /* should not happen */ + } + + return NULL; +} + + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_trv_print( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint indent ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_trv_print( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint indent ) { + libxsmm_blasint i; + libxsmm_blasint tree_print_indent = 4; + + for ( i = 0; i < indent; ++i ) { + if ( i < indent - tree_print_indent ) { + printf(" "); + } else { + if ( i % tree_print_indent == 0 ) { + printf("|"); + } else { + printf("-"); + } + } + } + + /* check if we are at an argument leaf, then we move up */ + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + if ( (cur_node->le == NULL) && (cur_node->ri == NULL) ) { + printf("ARG: %i %i %i %i %i\n", cur_node->info.arg.m, cur_node->info.arg.n, cur_node->info.arg.ld, cur_node->info.arg.in_pos, cur_node->info.arg.offs_in_pos ); + } else { + printf("ERROR: Arg cannot have left or right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + /* we have to push more in this branch */ + if ( cur_node->le != NULL ) { + printf("UNARY: %i %i (timestamp = %i, tmp = %i)\n", (int)cur_node->info.u_op.type, (int)cur_node->info.u_op.flags, cur_node->visit_timestamp, cur_node->tmp.id ); + libxsmm_matrix_eqn_trv_print( cur_node->le, indent+tree_print_indent ); + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( (cur_node->ri != NULL) ) { + printf("ERROR: Unary cannot have right childs!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + /* we have to push more in this branch */ + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) ) { + printf("BINARY: %i %i (timestamp = %i, tmp = %i)\n", (int)cur_node->info.b_op.type, (int)cur_node->info.b_op.flags, cur_node->visit_timestamp, cur_node->tmp.id ); + libxsmm_matrix_eqn_trv_print( cur_node->le, indent+tree_print_indent ); + libxsmm_matrix_eqn_trv_print( cur_node->ri, indent+tree_print_indent ); + } else { + printf("ERROR: Binary needs left and right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + /* we have to push more in this branch */ + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) && (cur_node->r2 != NULL) ) { + printf("TERNARY: %i %i (timestamp = %i, tmp = %i)\n", (int)cur_node->info.t_op.type, (int)cur_node->info.t_op.flags, cur_node->visit_timestamp, cur_node->tmp.id ); + libxsmm_matrix_eqn_trv_print( cur_node->le, indent+tree_print_indent ); + libxsmm_matrix_eqn_trv_print( cur_node->ri, indent+tree_print_indent ); + libxsmm_matrix_eqn_trv_print( cur_node->r2, indent+tree_print_indent ); + } else { + printf("ERROR: Ternary needs left, right and right2 child!\n"); + } + } else { + /* shouldn't happen */ + } +} + + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_trv_rpn_print( libxsmm_matrix_eqn_elem* cur_node ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_trv_rpn_print( libxsmm_matrix_eqn_elem* cur_node ) { + /* check if we are at an argument leaf, then we move up */ + if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_ARG ) { + if ( (cur_node->le == NULL) && (cur_node->ri == NULL) ) { + printf("ARG "); + } else { + printf("ERROR: Arg cannot have left or right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_UNARY ) { + /* we have to push more in this branch */ + if ( cur_node->le != NULL ) { + libxsmm_matrix_eqn_trv_rpn_print( cur_node->le ); + printf("UNARY-%i ", (int)cur_node->info.u_op.type ); + /* we have reached the root, as we are unary, there is no right branch */ + } else if ( (cur_node->ri != NULL) ) { + printf("ERROR: Unary cannot have right childs!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_BINARY ) { + /* we have to push more in this branch */ + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) ) { + libxsmm_matrix_eqn_trv_rpn_print( cur_node->le ); + libxsmm_matrix_eqn_trv_rpn_print( cur_node->ri ); + printf("BINARY-%i ", (int)cur_node->info.b_op.type ); + } else { + printf("ERROR: Binary needs left and right child!\n"); + } + } else if ( cur_node->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY ) { + /* we have to push more in this branch */ + if ( (cur_node->le != NULL) && (cur_node->ri != NULL) && (cur_node->r2 != NULL) ) { + libxsmm_matrix_eqn_trv_rpn_print( cur_node->le ); + libxsmm_matrix_eqn_trv_rpn_print( cur_node->ri ); + libxsmm_matrix_eqn_trv_rpn_print( cur_node->r2 ); + printf("TERNARY-%i ", (int)cur_node->info.t_op.type ); + } else { + printf("ERROR: Ternary needs left, right and right2 child!\n"); + } + } else { + /* shouldn't happen */ + } +} + + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_mov_head( libxsmm_blasint idx ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_mov_head( libxsmm_blasint idx ) { + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 1 ) { + fprintf( stderr, "the requested equation is already finalized!\n" ); + } + + libxsmm_matrix_eqns[idx]->eqn_cur = libxsmm_matrix_eqn_trv_head( libxsmm_matrix_eqns[idx]->eqn_cur ); + +#if 0 + printf("cur node address: %lld\n", libxsmm_matrix_eqns[idx]->eqn_cur ); +#endif + + /* let's see if we need seal the equation */ + if ( (libxsmm_matrix_eqns[idx]->eqn_cur == libxsmm_matrix_eqns[idx]->eqn_root) && + ( ((libxsmm_matrix_eqns[idx]->eqn_cur->type == LIBXSMM_MATRIX_EQN_NODE_UNARY) && (libxsmm_matrix_eqns[idx]->eqn_cur->le != NULL)) || + ((libxsmm_matrix_eqns[idx]->eqn_cur->type == LIBXSMM_MATRIX_EQN_NODE_BINARY) && (libxsmm_matrix_eqns[idx]->eqn_cur->ri != NULL)) || + ((libxsmm_matrix_eqns[idx]->eqn_cur->type == LIBXSMM_MATRIX_EQN_NODE_TERNARY) && (libxsmm_matrix_eqns[idx]->eqn_cur->r2 != NULL)) ) ) { + libxsmm_matrix_eqns[idx]->is_constructed = 1; + libxsmm_matrix_eqn_opt_exec_plan( idx ); + } +} + + +LIBXSMM_API_INTERN int libxsmm_matrix_eqn_is_ready_for_jit( libxsmm_blasint eqn_idx ) { + if ( libxsmm_matrix_eqns[eqn_idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + return 1; + } + if ( libxsmm_matrix_eqns[eqn_idx]->is_constructed == 0 ) { + fprintf( stderr, "the requested equation is not finalized, yet!\n" ); + return 2; + } + if ( libxsmm_matrix_eqns[eqn_idx]->is_optimized == 0 ) { + fprintf( stderr, "the requested equation is not optimized, yet!\n" ); + return 2; + } + + return 0; +} + + +LIBXSMM_API libxsmm_blasint libxsmm_matrix_eqn_create(void) { + libxsmm_blasint ret = libxsmm_matrix_eqns_count; + libxsmm_matrix_eqn_elem* node; + + /* lazy init of helper array */ + if ( libxsmm_matrix_eqns_init == 0 ) { + libxsmm_blasint i; + for ( i = 0; i < 256; ++i ) { + libxsmm_matrix_eqns[i] = NULL; + } + libxsmm_matrix_eqns_count = 0; + libxsmm_matrix_eqns_init = 1; + } + + libxsmm_matrix_eqns_count++; + + libxsmm_matrix_eqns[ret] = (libxsmm_matrix_eqn*) malloc( sizeof(libxsmm_matrix_eqn) ); + + node = (libxsmm_matrix_eqn_elem*) malloc( sizeof(libxsmm_matrix_eqn_elem) ); + + node->le = NULL; + node->ri = NULL; + node->up = NULL; + node->type = LIBXSMM_MATRIX_EQN_NODE_NONE; + + libxsmm_matrix_eqns[ret]->eqn_root = node; + libxsmm_matrix_eqns[ret]->eqn_cur = node; + libxsmm_matrix_eqns[ret]->is_constructed = 0; + libxsmm_matrix_eqns[ret]->is_optimized = 0; + libxsmm_matrix_eqns[ret]->unary_only = 0; + libxsmm_matrix_eqns[ret]->unary_only = 0; +#if 0 + printf("created equation no: %i\n", ret); + printf("root node address: %lld\n", libxsmm_matrix_eqns[ret]->eqn_cur ); +#endif + + return ret; +} + + +LIBXSMM_API int libxsmm_matrix_eqn_push_back_arg( const libxsmm_blasint idx, const libxsmm_blasint m, const libxsmm_blasint n, const libxsmm_blasint ld, const libxsmm_blasint in_pos, const libxsmm_blasint offs_in_pos, const libxsmm_datatype dtype ) { + union libxsmm_matrix_eqn_info info; + + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + return 1; + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 1 ) { + fprintf( stderr, "the requested equation is already finalized!\n" ); + return 2; + } + + info.arg.m = m; + info.arg.n = n; + info.arg.ld = ld; + info.arg.in_pos = in_pos; + info.arg.offs_in_pos = offs_in_pos; + info.arg.dtype = dtype; + libxsmm_matrix_eqns[idx]->eqn_cur = libxsmm_matrix_eqn_add_node( libxsmm_matrix_eqns[idx]->eqn_cur, LIBXSMM_MATRIX_EQN_NODE_ARG, info ); +#if 0 + printf("added arg node: %lld %i %i %i %i %i %i\n", libxsmm_matrix_eqns[idx]->eqn_cur, M, N, ld, in_pos, offs_in_pos, dtype ); +#endif + + /* move to the next head position in the tree */ + libxsmm_matrix_eqn_mov_head( idx ); + + return 0; +} + + +LIBXSMM_API int libxsmm_matrix_eqn_push_back_unary_op( const libxsmm_blasint idx, const libxsmm_meltw_unary_type type, const libxsmm_meltw_unary_flags flags, const libxsmm_datatype dtype ) { + union libxsmm_matrix_eqn_info info; + + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + return 1; + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 1 ) { + fprintf( stderr, "the requested equation is already finalized!\n" ); + return 2; + } + + info.u_op.type = type; + info.u_op.flags = flags; + info.u_op.dtype = dtype; + libxsmm_matrix_eqns[idx]->eqn_cur = libxsmm_matrix_eqn_add_node( libxsmm_matrix_eqns[idx]->eqn_cur, LIBXSMM_MATRIX_EQN_NODE_UNARY, info ); +#if 0 + printf("added unary node: %lld %i %i %i\n", libxsmm_matrix_eqns[idx]->eqn_cur, type, flags, dtype ); +#endif + + /* move to the next head position in the tree */ + libxsmm_matrix_eqn_mov_head( idx ); + + return 0; +} + + +LIBXSMM_API int libxsmm_matrix_eqn_push_back_binary_op( const libxsmm_blasint idx, const libxsmm_meltw_binary_type type, const libxsmm_meltw_binary_flags flags, const libxsmm_datatype dtype ) { + union libxsmm_matrix_eqn_info info; + + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + return 1; + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 1 ) { + fprintf( stderr, "the requested equation is already finalized!\n" ); + return 2; + } + + info.b_op.type = type; + info.b_op.flags = flags; + info.b_op.dtype = dtype; + libxsmm_matrix_eqns[idx]->eqn_cur = libxsmm_matrix_eqn_add_node( libxsmm_matrix_eqns[idx]->eqn_cur, LIBXSMM_MATRIX_EQN_NODE_BINARY, info ); +#if 0 + printf("added binary node: %lld %i %i %i\n", libxsmm_matrix_eqns[idx]->eqn_cur, type, flags, dtype ); +#endif + + /* move to the next head position in the tree */ + libxsmm_matrix_eqn_mov_head( idx ); + + return 0; +} + + +LIBXSMM_API int libxsmm_matrix_eqn_push_back_ternary_op( const libxsmm_blasint idx, const libxsmm_meltw_ternary_type type, const libxsmm_meltw_ternary_flags flags, const libxsmm_datatype dtype ) { + union libxsmm_matrix_eqn_info info; + + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + return 1; + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 1 ) { + fprintf( stderr, "the requested equation is already finalized!\n" ); + return 2; + } + + info.t_op.type = type; + info.t_op.flags = flags; + info.t_op.dtype = dtype; + libxsmm_matrix_eqns[idx]->eqn_cur = libxsmm_matrix_eqn_add_node( libxsmm_matrix_eqns[idx]->eqn_cur, LIBXSMM_MATRIX_EQN_NODE_TERNARY, info ); +#if 0 + printf("added ternary node: %lld %i %i %i\n", libxsmm_matrix_eqns[idx]->eqn_cur, type, flags, dtype ); +#endif + + /* move to the next head position in the tree */ + libxsmm_matrix_eqn_mov_head( idx ); + + return 0; +} + +LIBXSMM_API void libxsmm_matrix_eqn_tree_print( const libxsmm_blasint idx ) { + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 0 ) { + fprintf( stderr, "the requested equation is not yet finalized!\n" ); + } + + printf("\n"); + printf("Schematic of the expression tree (Pre-order)\n"); + libxsmm_matrix_eqn_trv_print( libxsmm_matrix_eqns[idx]->eqn_root, 0 ); + printf("\n"); +} + + +LIBXSMM_API void libxsmm_matrix_eqn_rpn_print( const libxsmm_blasint idx ) { + if ( libxsmm_matrix_eqns[idx] == NULL ) { + fprintf( stderr, "the requested equation doesn't exist!\n" ); + } + if ( libxsmm_matrix_eqns[idx]->is_constructed == 0 ) { + fprintf( stderr, "the requested equation is not yet finalized!\n" ); + } + + printf("\n"); + printf("HP calculator (RPN) print of the expression tree (Post-order)\n"); + libxsmm_matrix_eqn_trv_rpn_print( libxsmm_matrix_eqns[idx]->eqn_root ); + printf("\n\n"); +} + + diff --git a/third_party/libxsmm/src/libxsmm_matrixeqn.h b/third_party/libxsmm/src/libxsmm_matrixeqn.h new file mode 100644 index 0000000000000000000000000000000000000000..47d2243bf67973fec64f1efc74f6c1d2c94beedd --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_matrixeqn.h @@ -0,0 +1,148 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_MATRIXEQN_H +#define LIBXSMM_MATRIXEQN_H + +#define LEFT 0 +#define RIGHT 1 +#define RIGHT2 2 + +#include +/** + * TF includes src/libxsmm_main.h and uses LIBXSMM's sync primitives + * without including libxsmm_sync. However, libxsmm_sync.h shall be + * an explicit include separate from including libxsmm.h. + */ +#include "libxsmm_sync.h" + +LIBXSMM_EXTERN_C typedef enum libxsmm_matrix_eqn_node_type { + LIBXSMM_MATRIX_EQN_NODE_NONE = 0, + LIBXSMM_MATRIX_EQN_NODE_UNARY = 1, + LIBXSMM_MATRIX_EQN_NODE_BINARY = 2, + LIBXSMM_MATRIX_EQN_NODE_TERNARY = 4, + LIBXSMM_MATRIX_EQN_NODE_ARG = 8 +} libxsmm_matrix_eqn_node_type; + +LIBXSMM_EXTERN_C typedef enum libxsmm_matrix_eqn_bcast_type { + LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE = 0, + LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW = 1, + LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL = 2, + LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR = 4 +} libxsmm_matrix_eqn_bcast_type; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_unary_op { + libxsmm_meltw_unary_type type; + libxsmm_meltw_unary_flags flags; + libxsmm_datatype dtype; +} libxsmm_matrix_eqn_unary_op; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_binary_op { + libxsmm_meltw_binary_type type; + libxsmm_meltw_binary_flags flags; + libxsmm_datatype dtype; +} libxsmm_matrix_eqn_binary_op; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_ternary_op { + libxsmm_meltw_ternary_type type; + libxsmm_meltw_ternary_flags flags; + libxsmm_datatype dtype; +} libxsmm_matrix_eqn_ternary_op; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_arg { + libxsmm_blasint m; + libxsmm_blasint n; + libxsmm_blasint ld; + libxsmm_blasint in_pos; + libxsmm_blasint offs_in_pos; + libxsmm_datatype dtype; + libxsmm_matrix_eqn_bcast_type bcast_type; +} libxsmm_matrix_eqn_arg; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_tmp_info { + libxsmm_blasint id; + libxsmm_blasint m; + libxsmm_blasint n; + libxsmm_blasint ld; + libxsmm_datatype dtype; + libxsmm_matrix_eqn_bcast_type bcast_type; + libxsmm_blasint m_s; + libxsmm_blasint n_s; + libxsmm_blasint ld_s; + libxsmm_datatype dtype_s; + libxsmm_matrix_eqn_bcast_type bcast_type_s; + libxsmm_blasint m_t; + libxsmm_blasint n_t; + libxsmm_blasint ld_t; + libxsmm_datatype dtype_t; + libxsmm_matrix_eqn_bcast_type bcast_type_t; +} libxsmm_matrix_eqn_tmp_info; + +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_matrix_eqn_info { + libxsmm_matrix_eqn_unary_op u_op; + libxsmm_matrix_eqn_binary_op b_op; + libxsmm_matrix_eqn_ternary_op t_op; + libxsmm_matrix_eqn_arg arg; +} libxsmm_matrix_eqn_info; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_elem { + struct libxsmm_matrix_eqn_elem* le; + struct libxsmm_matrix_eqn_elem* ri; + struct libxsmm_matrix_eqn_elem* r2; + struct libxsmm_matrix_eqn_elem* up; + libxsmm_matrix_eqn_node_type type; + libxsmm_matrix_eqn_info info; + libxsmm_blasint reg_score; + libxsmm_blasint visit_timestamp; + libxsmm_matrix_eqn_tmp_info tmp; + libxsmm_blasint max_tmp_size; + libxsmm_blasint n_args; + libxsmm_blasint tree_max_comp_tsize; +} libxsmm_matrix_eqn_elem; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn { + libxsmm_matrix_eqn_elem* eqn_root; + libxsmm_matrix_eqn_elem* eqn_cur; + libxsmm_blasint is_constructed; + libxsmm_blasint is_optimized; + libxsmm_blasint unary_only; + libxsmm_blasint binary_only; +} libxsmm_matrix_eqn; + +/* Helper functions for matrix equation handling */ +LIBXSMM_API_INTERN libxsmm_matrix_eqn* libxsmm_matrix_eqn_get_equation( libxsmm_blasint eqn_idx ); +LIBXSMM_API_INTERN int libxsmm_matrix_eqn_is_ready_for_jit( libxsmm_blasint eqn_idx ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_assign_reg_scores( libxsmm_matrix_eqn_elem* cur_node ); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_create_exec_plan( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool ); +LIBXSMM_API_INTERN libxsmm_blasint reserve_tmp_storage(libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool); +LIBXSMM_API_INTERN void libxsmm_generator_assign_new_timestamp(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *current_timestamp ); +LIBXSMM_API_INTERN void libxsmm_generator_matequation_assign_timestamps(libxsmm_matrix_eqn *eqn); +LIBXSMM_API_INTERN void libxsmm_generator_reoptimize_eqn(libxsmm_matrix_eqn *eqn); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_adjust_tmp_sizes( libxsmm_matrix_eqn_elem* cur_node ); +LIBXSMM_API_INTERN int is_unary_opcode_reduce_kernel (unsigned int opcode); +LIBXSMM_API_INTERN int is_unary_opcode_transform_kernel (unsigned int opcode); +LIBXSMM_API_INTERN int is_unary_opcode_reduce_to_scalar (unsigned int opcode); +LIBXSMM_API_INTERN int is_binary_opcode_reduce_to_scalar (unsigned int opcode); + +LIBXSMM_API_INTERN +libxsmm_matrix_eqn_bcast_type get_bcast_type_unary(libxsmm_meltw_unary_flags flags); + +LIBXSMM_API_INTERN +libxsmm_matrix_eqn_bcast_type get_bcast_type_binary(libxsmm_meltw_binary_flags flags, unsigned int side); + +LIBXSMM_API_INTERN +libxsmm_matrix_eqn_bcast_type get_bcast_type_ternary(libxsmm_meltw_ternary_flags flags, unsigned int side); + +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_reassign_bcast_tmp(libxsmm_matrix_eqn *eqn); +LIBXSMM_API_INTERN void libxsmm_matrix_eqn_reassign_children_bcast_tmp(libxsmm_matrix_eqn *eqn, libxsmm_matrix_eqn_elem* cur_node); + + +#endif /*LIBXSMM_MATRIXEQN_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_memory.c b/third_party/libxsmm/src/libxsmm_memory.c new file mode 100644 index 0000000000000000000000000000000000000000..2226bbef63d1537aaa445bc33e1036757a29a098 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_memory.c @@ -0,0 +1,593 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_hash.h" +#include "libxsmm_diff.h" +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(LIBXSMM_MEMORY_STDLIB) && 0 +# define LIBXSMM_MEMORY_STDLIB +#endif +#if !defined(LIBXSMM_MEMORY_SW) && 0 +# define LIBXSMM_MEMORY_SW +#endif + + +#if !defined(LIBXSMM_MEMORY_SW) +LIBXSMM_APIVAR_DEFINE(unsigned char (*internal_diff_function)(const void*, const void*, unsigned char)); +LIBXSMM_APIVAR_DEFINE(int (*internal_memcmp_function)(const void*, const void*, size_t)); +#endif + + +LIBXSMM_API_INLINE +unsigned char internal_diff_sw(const void* a, const void* b, unsigned char size) +{ +#if defined(LIBXSMM_MEMORY_STDLIB) && defined(LIBXSMM_MEMORY_SW) + return (unsigned char)memcmp(a, b, size); +#else + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + unsigned char i; + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xF0); i += 16) { + LIBXSMM_DIFF_16_DECL(aa); + LIBXSMM_DIFF_16_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_16(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#endif +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_GENERIC) +unsigned char internal_diff_sse(const void* a, const void* b, unsigned char size) +{ +#if defined(LIBXSMM_INTRINSICS_X86) && !defined(LIBXSMM_MEMORY_SW) + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + unsigned char i; + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xF0); i += 16) { + LIBXSMM_DIFF_SSE_DECL(aa); + LIBXSMM_DIFF_SSE_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_SSE(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#else + return internal_diff_sw(a, b, size); +#endif +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) +unsigned char internal_diff_avx2(const void* a, const void* b, unsigned char size) +{ +#if defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_MEMORY_SW) + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + unsigned char i; + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xE0); i += 32) { + LIBXSMM_DIFF_AVX2_DECL(aa); + LIBXSMM_DIFF_AVX2_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_AVX2(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#else + return internal_diff_sw(a, b, size); +#endif +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +unsigned char internal_diff_avx512(const void* a, const void* b, unsigned char size) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512) && !defined(LIBXSMM_MEMORY_SW) + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + unsigned char i; + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xC0); i += 64) { + LIBXSMM_DIFF_AVX512_DECL(aa); + LIBXSMM_DIFF_AVX512_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_AVX512(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#else + return internal_diff_sw(a, b, size); +#endif +} + + +LIBXSMM_API_INLINE +int internal_memcmp_sw(const void* a, const void* b, size_t size) +{ +#if defined(LIBXSMM_MEMORY_STDLIB) + return memcmp(a, b, size); +#else + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + size_t i; + LIBXSMM_DIFF_16_DECL(aa); + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xFFFFFFFFFFFFFFF0); i += 16) { + LIBXSMM_DIFF_16_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_16(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#endif +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_GENERIC) +int internal_memcmp_sse(const void* a, const void* b, size_t size) +{ +#if defined(LIBXSMM_INTRINSICS_X86) && !defined(LIBXSMM_MEMORY_SW) + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + size_t i; + LIBXSMM_DIFF_SSE_DECL(aa); + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xFFFFFFFFFFFFFFF0); i += 16) { + LIBXSMM_DIFF_SSE_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_SSE(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#else + return internal_memcmp_sw(a, b, size); +#endif +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) +int internal_memcmp_avx2(const void* a, const void* b, size_t size) +{ +#if defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_MEMORY_SW) + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + size_t i; + LIBXSMM_DIFF_AVX2_DECL(aa); + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xFFFFFFFFFFFFFFE0); i += 32) { + LIBXSMM_DIFF_AVX2_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_AVX2(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#else + return internal_memcmp_sw(a, b, size); +#endif +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +int internal_memcmp_avx512(const void* a, const void* b, size_t size) +{ +#if defined(LIBXSMM_INTRINSICS_AVX512) && !defined(LIBXSMM_MEMORY_SW) + const uint8_t *const a8 = (const uint8_t*)a, *const b8 = (const uint8_t*)b; + size_t i; + LIBXSMM_DIFF_AVX512_DECL(aa); + LIBXSMM_PRAGMA_UNROLL/*_N(2)*/ + for (i = 0; i < (size & 0xFFFFFFFFFFFFFFC0); i += 64) { + LIBXSMM_DIFF_AVX512_LOAD(aa, a8 + i); + if (LIBXSMM_DIFF_AVX512(aa, b8 + i, 0/*dummy*/)) return 1; + } + for (; i < size; ++i) if (a8[i] ^ b8[i]) return 1; + return 0; +#else + return internal_memcmp_sw(a, b, size); +#endif +} + + +LIBXSMM_API_INTERN void libxsmm_memory_init(int target_arch) +{ +#if defined(LIBXSMM_MEMORY_SW) + LIBXSMM_UNUSED(target_arch); +#else + if (LIBXSMM_X86_AVX512 <= target_arch) { +# if defined(LIBXSMM_DIFF_AVX512_ENABLED) + internal_diff_function = internal_diff_avx512; +# else + internal_diff_function = internal_diff_avx2; +# endif +# if defined(LIBXSMM_DIFF_AVX512_ENABLED) + internal_memcmp_function = internal_memcmp_avx512; +# else + internal_memcmp_function = internal_memcmp_avx2; +# endif + } + else if (LIBXSMM_X86_AVX2 <= target_arch) { + internal_diff_function = internal_diff_avx2; + internal_memcmp_function = internal_memcmp_avx2; + } + else if (LIBXSMM_X86_GENERIC <= target_arch) { + internal_diff_function = internal_diff_sse; + internal_memcmp_function = internal_memcmp_sse; + } + else { + internal_diff_function = internal_diff_sw; + internal_memcmp_function = internal_memcmp_sw; + } + LIBXSMM_ASSERT(NULL != internal_diff_function); + LIBXSMM_ASSERT(NULL != internal_memcmp_function); +#endif +} + + +LIBXSMM_API_INTERN void libxsmm_memory_finalize(void) +{ +#if !defined(NDEBUG) && !defined(LIBXSMM_MEMORY_SW) + internal_diff_function = NULL; + internal_memcmp_function = NULL; +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff_4(const void* a, const void* b, ...) +{ +#if defined(LIBXSMM_MEMORY_SW) + return internal_diff_sw(a, b, 4); +#else + LIBXSMM_DIFF_4_DECL(a4); + LIBXSMM_DIFF_4_LOAD(a4, a); + return LIBXSMM_DIFF_4(a4, b, 0/*dummy*/); +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff_8(const void* a, const void* b, ...) +{ +#if defined(LIBXSMM_MEMORY_SW) + return internal_diff_sw(a, b, 8); +#else + LIBXSMM_DIFF_8_DECL(a8); + LIBXSMM_DIFF_8_LOAD(a8, a); + return LIBXSMM_DIFF_8(a8, b, 0/*dummy*/); +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff_16(const void* a, const void* b, ...) +{ +#if defined(LIBXSMM_MEMORY_SW) + return internal_diff_sw(a, b, 16); +#else + LIBXSMM_DIFF_16_DECL(a16); + LIBXSMM_DIFF_16_LOAD(a16, a); + return LIBXSMM_DIFF_16(a16, b, 0/*dummy*/); +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff_32(const void* a, const void* b, ...) +{ +#if defined(LIBXSMM_MEMORY_SW) + return internal_diff_sw(a, b, 32); +#else + LIBXSMM_DIFF_32_DECL(a32); + LIBXSMM_DIFF_32_LOAD(a32, a); + return LIBXSMM_DIFF_32(a32, b, 0/*dummy*/); +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff_48(const void* a, const void* b, ...) +{ +#if defined(LIBXSMM_MEMORY_SW) + return internal_diff_sw(a, b, 48); +#else + LIBXSMM_DIFF_48_DECL(a48); + LIBXSMM_DIFF_48_LOAD(a48, a); + return LIBXSMM_DIFF_48(a48, b, 0/*dummy*/); +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff_64(const void* a, const void* b, ...) +{ +#if defined(LIBXSMM_MEMORY_SW) + return internal_diff_sw(a, b, 64); +#else + LIBXSMM_DIFF_64_DECL(a64); + LIBXSMM_DIFF_64_LOAD(a64, a); + return LIBXSMM_DIFF_64(a64, b, 0/*dummy*/); +#endif +} + + +LIBXSMM_API unsigned char libxsmm_diff(const void* a, const void* b, unsigned char size) +{ +#if defined(LIBXSMM_MEMORY_SW) && !defined(LIBXSMM_MEMORY_STDLIB) + return internal_diff_sw(a, b, size); +#else +# if defined(LIBXSMM_MEMORY_STDLIB) + return 0 != memcmp(a, b, size); +# elif (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) && defined(LIBXSMM_DIFF_AVX512_ENABLED) + return internal_diff_avx512(a, b, size); +# elif (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_diff_avx2(a, b, size); +# elif (LIBXSMM_X86_SSE3 <= LIBXSMM_STATIC_TARGET_ARCH) +# if (LIBXSMM_X86_AVX2 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_diff_sse(a, b, size); +# else /* pointer based function call */ +# if defined(LIBXSMM_INIT_COMPLETED) + LIBXSMM_ASSERT(NULL != internal_diff_function); + return internal_diff_function(a, b, size); +# else + return (unsigned char)(NULL != internal_diff_function + ? internal_diff_function(a, b, size) + : internal_diff_sse(a, b, size)); +# endif +# endif +# else + return internal_diff_sw(a, b, size); +# endif +#endif +} + + +LIBXSMM_API unsigned int libxsmm_diff_n(const void* a, const void* bn, unsigned char size, + unsigned char stride, unsigned int hint, unsigned int n) +{ + unsigned int result; + LIBXSMM_ASSERT(size <= stride); +#if defined(LIBXSMM_MEMORY_STDLIB) && !defined(LIBXSMM_MEMORY_SW) + LIBXSMM_DIFF_N(unsigned int, result, memcmp, a, bn, size, stride, hint, n); +#else +# if !defined(LIBXSMM_MEMORY_SW) + switch (size) { + case 64: { + LIBXSMM_DIFF_64_DECL(a64); + LIBXSMM_DIFF_64_LOAD(a64, a); + LIBXSMM_DIFF_N(unsigned int, result, LIBXSMM_DIFF_64, a64, bn, size, stride, hint, n); + } break; + case 48: { + LIBXSMM_DIFF_48_DECL(a48); + LIBXSMM_DIFF_48_LOAD(a48, a); + LIBXSMM_DIFF_N(unsigned int, result, LIBXSMM_DIFF_48, a48, bn, size, stride, hint, n); + } break; + case 32: { + LIBXSMM_DIFF_32_DECL(a32); + LIBXSMM_DIFF_32_LOAD(a32, a); + LIBXSMM_DIFF_N(unsigned int, result, LIBXSMM_DIFF_32, a32, bn, size, stride, hint, n); + } break; + case 16: { + LIBXSMM_DIFF_16_DECL(a16); + LIBXSMM_DIFF_16_LOAD(a16, a); + LIBXSMM_DIFF_N(unsigned int, result, LIBXSMM_DIFF_16, a16, bn, size, stride, hint, n); + } break; + case 8: { + LIBXSMM_DIFF_8_DECL(a8); + LIBXSMM_DIFF_8_LOAD(a8, a); + LIBXSMM_DIFF_N(unsigned int, result, LIBXSMM_DIFF_8, a8, bn, size, stride, hint, n); + } break; + case 4: { + LIBXSMM_DIFF_4_DECL(a4); + LIBXSMM_DIFF_4_LOAD(a4, a); + LIBXSMM_DIFF_N(unsigned int, result, LIBXSMM_DIFF_4, a4, bn, size, stride, hint, n); + } break; + default: +# endif + { + LIBXSMM_DIFF_N(unsigned int, result, libxsmm_diff, a, bn, size, stride, hint, n); + } +# if !defined(LIBXSMM_MEMORY_SW) + } +# endif +#endif + return result; +} + + +LIBXSMM_API int libxsmm_memcmp(const void* a, const void* b, size_t size) +{ +#if defined(LIBXSMM_MEMORY_SW) && !defined(LIBXSMM_MEMORY_STDLIB) + return internal_memcmp_sw(a, b, size); +#else +# if defined(LIBXSMM_MEMORY_STDLIB) + return memcmp(a, b, size); +# elif (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) && defined(LIBXSMM_DIFF_AVX512_ENABLED) + return internal_memcmp_avx512(a, b, size); +# elif (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) + return internal_memcmp_avx2(a, b, size); +# elif (LIBXSMM_X86_SSE3 <= LIBXSMM_STATIC_TARGET_ARCH) +# if (LIBXSMM_X86_AVX2 > LIBXSMM_MAX_STATIC_TARGET_ARCH) + return internal_memcmp_sse(a, b, size); +# else /* pointer based function call */ +# if defined(LIBXSMM_INIT_COMPLETED) + LIBXSMM_ASSERT(NULL != internal_memcmp_function); + return internal_memcmp_function(a, b, size); +# else + return NULL != internal_memcmp_function + ? internal_memcmp_function(a, b, size) + : internal_memcmp_sse(a, b, size); +# endif +# endif +# else + return internal_memcmp_sw(a, b, size); +# endif +#endif +} + + +LIBXSMM_API unsigned int libxsmm_hash(const void* data, unsigned int size, unsigned int seed) +{ + LIBXSMM_INIT + return libxsmm_crc32(seed, data, size); +} + + +LIBXSMM_API unsigned long long libxsmm_hash_string(const char* string) +{ + unsigned long long result; + const size_t length = (NULL != string ? strlen(string) : 0); + if (sizeof(result) < length) { + const size_t length2 = length / 2; + unsigned int seed32 = 0; /* seed=0: match else-optimization */ + LIBXSMM_INIT + seed32 = libxsmm_crc32(seed32, string, length2); + result = libxsmm_crc32(seed32, string + length2, length - length2); + result = (result << 32) | seed32; + } + else { /* reinterpret directly as hash value */ +#if 1 + result = (unsigned long long)string; +#else + char *const s = (char*)&result; signed char i; + for (i = 0; i < (signed char)length; ++i) s[i] = string[i]; + for (; i < (signed char)sizeof(result); ++i) s[i] = 0; +#endif + } + return result; +} + + +LIBXSMM_API const char* libxsmm_stristr(const char* a, const char* b) +{ + const char* result = NULL; + if (NULL != a && NULL != b && '\0' != *a && '\0' != *b) { + do { + if (tolower(*a) != tolower(*b)) { + ++a; + } + else { + const char* c = b; + result = a; + while ('\0' != *++a && '\0' != *++c) { + if (tolower(*a) != tolower(*c)) { + result = NULL; + break; + } + } + if ('\0' != c[0] && '\0' != c[1]) { + result = NULL; + } + else break; + } + } while ('\0' != *a); + } + return result; +} + + +LIBXSMM_API int libxsmm_aligned(const void* ptr, const size_t* inc, int* alignment) +{ + const int minalign = 4 * libxsmm_cpuid_vlen32(libxsmm_target_archid); + const uintptr_t address = (uintptr_t)ptr; + int ptr_is_aligned; + LIBXSMM_ASSERT(LIBXSMM_ISPOT(minalign)); + if (NULL == alignment) { + ptr_is_aligned = !LIBXSMM_MOD2(address, (uintptr_t)minalign); + } + else { + const unsigned int nbits = LIBXSMM_INTRINSICS_BITSCANFWD64(address); + *alignment = (32 > nbits ? (1 << nbits) : INT_MAX); + ptr_is_aligned = (minalign <= *alignment); + } + return ptr_is_aligned && (NULL == inc || !LIBXSMM_MOD2(*inc, (size_t)minalign)); +} + + +#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xhash)(int* /*hash_seed*/, const void* /*data*/, const int* /*size*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xhash)(int* hash_seed, const void* data, const int* size) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != hash_seed && NULL != data && NULL != size && 0 <= *size) +#endif + { + *hash_seed = (int)(libxsmm_hash(data, (unsigned int)*size, (unsigned int)*hash_seed) & 0x7FFFFFFF/*sign-bit*/); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xhash specified!\n"); + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xdiff)(int* /*result*/, const void* /*a*/, const void* /*b*/, const long long* /*size*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xdiff)(int* result, const void* a, const void* b, const long long* size) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != result && NULL != a && NULL != b && NULL != size && 0 <= *size) +#endif + { + *result = libxsmm_memcmp(a, b, (size_t)*size); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xdiff specified!\n"); + } +#endif +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xclear)(void* /*dst*/, const int* /*size*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_xclear)(void* dst, const int* size) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != dst && NULL != size && 0 <= *size && 128 > *size) +#endif + { + LIBXSMM_MEMSET127(dst, 0, *size); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_xclear specified!\n"); + } +#endif +} + + +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_aligned)(int* /*result*/, const void* /*ptr*/, const int* /*inc*/, int* /*alignment*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_aligned)(int* result, const void* ptr, const int* inc, int* alignment) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != result) +#endif + { + const size_t next = (NULL != inc ? *inc : 0); + *result = libxsmm_aligned(ptr, &next, alignment); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_aligned specified!\n"); + } +#endif +} + +#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_mhd.c b/third_party/libxsmm/src/libxsmm_mhd.c new file mode 100644 index 0000000000000000000000000000000000000000..864e01e2d2702463ea51fc8dd88176dbb68ff9fd --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_mhd.c @@ -0,0 +1,925 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" /* libxsmm_typesize */ + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(LIBXSMM_MHD_MAX_LINELENGTH) +# define LIBXSMM_MHD_MAX_LINELENGTH 1024 +#endif + +#if !defined(LIBXSMM_MHD_MAX_ELEMSIZE) +# define LIBXSMM_MHD_MAX_ELEMSIZE 8 +#endif + +#define LIBXSMM_MHD_MINMAX(TYPE, DATA, NELEMENTS, PMIN_INOUT, PMAX_INOUT) { \ + LIBXSMM_ASSERT(NULL != (PMIN_INOUT) && NULL != (PMAX_INOUT)); \ + if (0 < (NELEMENTS)) { \ + size_t libxsmm_mhd_minmax_index_ = 0; \ + do { \ + TYPE libxsmm_mhd_minmax_value_; \ + LIBXSMM_ASSERT(NULL != (DATA)); \ + libxsmm_mhd_minmax_value_ = ((const TYPE*)DATA)[libxsmm_mhd_minmax_index_]; \ + if (libxsmm_mhd_minmax_value_ < *((const TYPE*)PMIN_INOUT)) { \ + *((TYPE*)PMIN_INOUT) = libxsmm_mhd_minmax_value_; \ + } \ + else if (libxsmm_mhd_minmax_value_ > *((const TYPE*)PMAX_INOUT)) { \ + *((TYPE*)PMAX_INOUT) = libxsmm_mhd_minmax_value_; \ + } \ + ++libxsmm_mhd_minmax_index_; \ + } while (libxsmm_mhd_minmax_index_ < (NELEMENTS)); \ + } \ + else *((TYPE*)PMIN_INOUT) = *((TYPE*)PMAX_INOUT) = 0; \ +} + +#define LIBXSMM_MHD_TYPE_PROMOTE(DST_TYPE, SRC_TYPE) \ + (LIBXSMM_MHD_ELEMTYPE_I64 > (DST_TYPE) || (LIBXSMM_MHD_ELEMTYPE_U64 > (DST_TYPE) \ + ? /*dst is signed*/(LIBXSMM_MHD_ELEMTYPE_U64 > (SRC_TYPE) ? ((SRC_TYPE) > (DST_TYPE)) : 0) \ + : /*dst is unsigned*/(LIBXSMM_MHD_ELEMTYPE_U64 > (SRC_TYPE) ? 0 : ((SRC_TYPE) > (DST_TYPE))))) + +#define LIBXSMM_MHD_ELEMENT_CONVERSION_F(SRC_TYPE, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT) { \ + const double h = (0.5 - (DST_TYPE)0.5); \ + SRC_TYPE s = *((const SRC_TYPE*)PSRC); \ + double s0 = 0, s1 = 0; \ + if (NULL != (PSRC_MIN) && LIBXSMM_NOTNAN(s)) { \ + LIBXSMM_ASSERT_MSG(NULL != (PSRC_MAX) && *((const SRC_TYPE*)PSRC_MIN) <= s && s <= *((const SRC_TYPE*)PSRC_MAX), "Invalid value range"); \ + s0 = (double)*((const SRC_TYPE*)PSRC_MIN); s1 = (double)*((const SRC_TYPE*)PSRC_MAX); \ + } \ + if (LIBXSMM_MHD_ELEMTYPE_I64 <= (DST_ENUM) && s0 < s1) { /* scale */ \ + if (LIBXSMM_MHD_ELEMTYPE_U64 <= (DST_ENUM)) { \ + const double s0pos = LIBXSMM_MAX(0, s0), s1pos = LIBXSMM_MAX(0, s1), scale = (s0pos < s1pos ? ((s1 - s0) / (s1pos - s0pos)) : 1); \ + s = (SRC_TYPE)(scale * (double)LIBXSMM_MAX(0, s)); \ + s0 = s0pos; s1 = s1pos; \ + } \ + else if (0 == LIBXSMM_MHD_TYPE_PROMOTE(DST_ENUM, SRC_ENUM) && 0 > s0 && 0 < s1) { \ + s1 = LIBXSMM_MAX(-s0, s1); s0 = -s1; \ + } \ + { const double d0 = (0 <= s0 ? 0 : (DST_MIN)), d1 = (0 <= s1 ? (DST_MAX) : 0), d = ((double)s - s0) * (d1 - d0) / (s1 - s0) + d0; \ + *((DST_TYPE*)PDST) = (DST_TYPE)LIBXSMM_CLMP(0 <= d ? (d + h) : (d - h), d0, d1); \ + } \ + } \ + else if (0 == LIBXSMM_MHD_TYPE_PROMOTE(DST_ENUM, SRC_ENUM)) { /* clamp */ \ + *((DST_TYPE*)PDST) = (DST_TYPE)(0 <= s ? LIBXSMM_CLMP(s + h, DST_MIN, DST_MAX) : LIBXSMM_CLMP(s - h, DST_MIN, DST_MAX)); \ + } \ + else { /* promote */ \ + *((DST_TYPE*)PDST) = (DST_TYPE)(0 <= s ? (s + h) : (s - h)); \ + } \ + RESULT = EXIT_SUCCESS; \ +} + +#define LIBXSMM_MHD_ELEMENT_CONVERSION_I(SRC_TYPE, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT) { \ + const double h = (0.5 - (DST_TYPE)0.5); \ + SRC_TYPE s = *((const SRC_TYPE*)PSRC); \ + double s0 = 0, s1 = 0; \ + if (NULL != (PSRC_MIN)) { \ + LIBXSMM_ASSERT_MSG(NULL != (PSRC_MAX) && *((const SRC_TYPE*)PSRC_MIN) <= s && s <= *((const SRC_TYPE*)PSRC_MAX), "Invalid value range"); \ + s0 = (double)*((const SRC_TYPE*)PSRC_MIN); s1 = (double)*((const SRC_TYPE*)PSRC_MAX); \ + } \ + if (LIBXSMM_MHD_ELEMTYPE_I64 <= (DST_ENUM) && s0 < s1) { /* scale */ \ + if (LIBXSMM_MHD_ELEMTYPE_U64 <= (DST_ENUM)) { \ + const double s0pos = LIBXSMM_MAX(0, s0), s1pos = LIBXSMM_MAX(0, s1), scale = (s0pos < s1pos ? ((s1 - s0) / (s1pos - s0pos)) : 1); \ + const double ss = scale * (double)LIBXSMM_MAX(0, s); \ + s = (SRC_TYPE)(0 <= ss ? (ss + h) : (ss - h)); \ + s0 = s0pos; s1 = s1pos; \ + } \ + else if (0 == LIBXSMM_MHD_TYPE_PROMOTE(DST_ENUM, SRC_ENUM) && 0 > s0 && 0 < s1) { \ + s1 = LIBXSMM_MAX(-s0, s1); s0 = -s1; \ + } \ + { const double d0 = (0 <= s0 ? 0 : (DST_MIN)), d1 = (0 <= s1 ? (DST_MAX) : 0), d = ((double)s - s0) * (d1 - d0) / (s1 - s0) + d0; \ + *((DST_TYPE*)PDST) = (DST_TYPE)LIBXSMM_CLMP(0 <= d ? (d + h) : (d - h), d0, d1); \ + } \ + } \ + else if (0 == LIBXSMM_MHD_TYPE_PROMOTE(DST_ENUM, SRC_ENUM)) { /* clamp */ \ + *((DST_TYPE*)PDST) = (DST_TYPE)LIBXSMM_CLMP(s, DST_MIN, DST_MAX); \ + } \ + else { /* promote */ \ + *((DST_TYPE*)PDST) = (DST_TYPE)s; \ + } \ + RESULT = EXIT_SUCCESS; \ +} + +#define LIBXSMM_MHD_ELEMENT_CONVERSION_U LIBXSMM_MHD_ELEMENT_CONVERSION_I + +#define LIBXSMM_MHD_ELEMENT_CONVERSION(DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT) { \ + LIBXSMM_ASSERT_MSG(NULL != (PDST) && NULL != (PSRC), "Invalid input or output"); \ + switch(SRC_ENUM) { \ + case LIBXSMM_MHD_ELEMTYPE_F64: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_F(double, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_F32: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_F(float, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_BF16: { \ + LIBXSMM_ASSERT_MSG(0, "Not implemented yet"); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_I64: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_I(long long, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_I32: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_I(int, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_I16: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_I(short, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_I8: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_I(signed char, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_U64: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_U(unsigned long long, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_U32: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_U(unsigned int, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_U16: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_U(unsigned short, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + case LIBXSMM_MHD_ELEMTYPE_U8: { \ + LIBXSMM_MHD_ELEMENT_CONVERSION_U(unsigned char, DST_TYPE, DST_ENUM, DST_MIN, DST_MAX, PDST, SRC_ENUM, PSRC, PSRC_MIN, PSRC_MAX, RESULT); \ + } break; \ + default: RESULT = EXIT_FAILURE; \ + } \ +} + + +LIBXSMM_API const char* libxsmm_mhd_typename(libxsmm_mhd_elemtype type, size_t* typesize, const char** ctypename) +{ + const char *mhd_typename = NULL, *c_typename = NULL; + size_t size = 0; + switch (type) { + case LIBXSMM_MHD_ELEMTYPE_F64: { size = 8; mhd_typename = "MET_DOUBLE"; c_typename = "double"; } break; + case LIBXSMM_MHD_ELEMTYPE_F32: { size = 4; mhd_typename = "MET_FLOAT"; c_typename = "float"; } break; + case LIBXSMM_MHD_ELEMTYPE_BF16: { size = 2; mhd_typename = "MET_BFLOAT"; c_typename = "unsigned short"; } break; + case LIBXSMM_MHD_ELEMTYPE_I64: { size = 8; mhd_typename = "MET_LONG"; c_typename = "signed long long"; } break; + case LIBXSMM_MHD_ELEMTYPE_I32: { size = 4; mhd_typename = "MET_INT"; c_typename = "signed int"; } break; + case LIBXSMM_MHD_ELEMTYPE_I16: { size = 2; mhd_typename = "MET_SHORT"; c_typename = "signed short"; } break; + case LIBXSMM_MHD_ELEMTYPE_I8: { size = 1; mhd_typename = "MET_CHAR"; c_typename = "signed char"; } break; + case LIBXSMM_MHD_ELEMTYPE_U64: { size = 8; mhd_typename = "MET_ULONG"; c_typename = "unsigned long long"; } break; + case LIBXSMM_MHD_ELEMTYPE_U32: { size = 4; mhd_typename = "MET_UINT"; c_typename = "unsigned int"; } break; + case LIBXSMM_MHD_ELEMTYPE_U16: { size = 2; mhd_typename = "MET_USHORT"; c_typename = "unsigned short"; } break; + case LIBXSMM_MHD_ELEMTYPE_U8: { size = 1; mhd_typename = "MET_UCHAR"; c_typename = "unsigned char"; } break; + default: size = libxsmm_typesize((libxsmm_datatype)type); /* fallback */ + } + LIBXSMM_ASSERT(size <= LIBXSMM_MHD_MAX_ELEMSIZE); + if (NULL != ctypename) *ctypename = c_typename; + if (NULL != typesize) *typesize = size; + return mhd_typename; +} + + +LIBXSMM_API libxsmm_mhd_elemtype libxsmm_mhd_typeinfo(const char elemname[]) +{ + libxsmm_mhd_elemtype result = LIBXSMM_MHD_ELEMTYPE_UNKNOWN; + if (0 == strcmp("MET_DOUBLE", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_F64; + } + else if (0 == strcmp("MET_FLOAT", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_F32; + } + else if (0 == strcmp("MET_BFLOAT", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_BF16; + } + else if (0 == strcmp("MET_LONG", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_I64; + } + else if (0 == strcmp("MET_INT", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_I32; + } + else if (0 == strcmp("MET_SHORT", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_I16; + } + else if (0 == strcmp("MET_CHAR", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_I8; + } + else if (0 == strcmp("MET_ULONG", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_U64; + } + else if (0 == strcmp("MET_UINT", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_U32; + } + else if (0 == strcmp("MET_USHORT", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_U16; + } + else if (0 == strcmp("MET_UCHAR", elemname)) { + result = LIBXSMM_MHD_ELEMTYPE_U8; + } + return result; +} + +LIBXSMM_API_INLINE int internal_mhd_readline(char buffer[], char split, size_t* key_end, size_t* value_begin) +{ + int result = EXIT_SUCCESS; + char *const isplit = strchr(buffer, split); + + if (0 != isplit) { + char* i = isplit; + LIBXSMM_ASSERT(0 != key_end && 0 != value_begin); + while (buffer != i) { --i; if (0 == isspace((int)(*i))) break; } + *key_end = i - buffer + 1; + i = isplit; + while ('\n' != *++i) if (0 == isspace((int)(*i))) break; + *value_begin = i - buffer; + while (0 != *i && 0 != isprint((int)(*i))) ++i; + if (0 == isprint((int)(*i))) *i = 0; /* fix-up */ + if (i <= (buffer + *value_begin)) { + result = EXIT_FAILURE; + } + } + else { + result = EXIT_FAILURE; + } + + return result; +} + + +LIBXSMM_API int libxsmm_mhd_read_header(const char header_filename[], size_t filename_max_length, + char filename[], size_t* ndims, size_t size[], size_t* ncomponents, libxsmm_mhd_elemtype* type, + size_t* header_size, size_t* extension_size) +{ + int result = EXIT_SUCCESS; + char buffer[LIBXSMM_MHD_MAX_LINELENGTH]; + FILE *const file = (0 < filename_max_length && 0 != filename && 0 != ndims && 0 < *ndims && 0 != size && 0 != type && 0 != ncomponents) + ? fopen(header_filename, "rb") : 0; + + if (0 != file) { + size_t key_end, value_begin; + if (0 != extension_size) *extension_size = 0; + if (0 != header_size) *header_size = 0; + memset(size, 0, *ndims * sizeof(*size)); + *type = LIBXSMM_MHD_ELEMTYPE_UNKNOWN; + *ncomponents = 1; + if (header_filename != filename) { + *filename = 0; + } + + while (0 != fgets(buffer, sizeof(buffer), file) && EXIT_SUCCESS == result && + EXIT_SUCCESS == internal_mhd_readline(buffer, '=', &key_end, &value_begin)) + { + if (0 == strncmp("NDims", buffer, key_end) + && key_end == strlen("NDims")) + { + const int value = atoi(buffer + value_begin); + if (0 < value && value <= ((int)*ndims)) { + *ndims = value; + } + } + else if (0 == strncmp("ElementNumberOfChannels", buffer, key_end) + && key_end == strlen("ElementNumberOfChannels")) + { + const int value = atoi(buffer + value_begin); + if (0 < value) { + *ncomponents = value; + } + else { + result = EXIT_FAILURE; + } + } + else if (0 != extension_size + && 0 == strncmp("ExtensionDataSize", buffer, key_end) + && key_end == strlen("ExtensionDataSize")) + { + const int value = atoi(buffer + value_begin); + if (0 <= value) { + *extension_size = value; + } + else { + result = EXIT_FAILURE; + } + } + else if (0 == strncmp("ElementType", buffer, key_end) + && key_end == strlen("ElementType")) + { + const libxsmm_mhd_elemtype value = libxsmm_mhd_typeinfo(buffer + value_begin); + if (LIBXSMM_MHD_ELEMTYPE_UNKNOWN != value) { + *type = value; + } + } + else if (0 == strncmp("ElementDataFile", buffer, key_end) + && key_end == strlen("ElementDataFile")) + { + const char *const value = buffer + value_begin; + if (0 == strcmp("LOCAL", value) || 0 == strcmp(header_filename, value)) { + if (header_size) { + const long file_position = ftell(file); /* determine the header size */ + const size_t len = strlen(header_filename); + if (0 <= file_position && len < filename_max_length) { + memcpy(filename, header_filename, len + 1); + LIBXSMM_ASSERT(0 == filename[len]); + *header_size = ftell(file); + } + else { + result = EXIT_FAILURE; + } + break; /* ElementDataFile is just before the raw data */ + } + } + else { + const size_t len = strlen(value); + if (len < filename_max_length) { + memcpy(filename, value, len + 1); + LIBXSMM_ASSERT(0 == filename[len]); + } + else { + result = EXIT_FAILURE; + } + } + } + else if (0 == strncmp("DimSize", buffer, key_end) + && key_end == strlen("DimSize")) + { + char* value = buffer + value_begin; + size_t *isize = size, n = 0; + while (EXIT_SUCCESS == internal_mhd_readline(value, ' ', &key_end, &value_begin) && n < *ndims) { + const int ivalue = atoi(value); + if (0 < ivalue) { + *isize = ivalue; + } + else { + result = EXIT_FAILURE; + } + value += key_end + 1; + ++isize; + ++n; + } + if (EXIT_SUCCESS == result) { + if (0 != *value && n < *ndims) { + const int ivalue = atoi(value); + if (0 < ivalue) { + *isize = ivalue; + } + else { + result = EXIT_FAILURE; + } + ++n; + } +#if 0 + else { + result = EXIT_FAILURE; + } +#endif + } + } + else if (0 == strncmp("BinaryData", buffer, key_end) + && key_end == strlen("BinaryData")) + { + const char *const value = buffer + value_begin; + if (0 == strcmp("False", value) || 0 != strcmp("True", value)) { + result = EXIT_FAILURE; + } + } + else if (0 == strncmp("CompressedData", buffer, key_end) + && key_end == strlen("CompressedData")) + { + const char *const value = buffer + value_begin; + if (0 == strcmp("True", value) || 0 != strcmp("False", value)) { + result = EXIT_FAILURE; + } + } + else if ((0 == strncmp("BinaryDataByteOrderMSB", buffer, key_end) && key_end == strlen("BinaryDataByteOrderMSB")) + || (0 == strncmp("ElementByteOrderMSB", buffer, key_end) && key_end == strlen("ElementByteOrderMSB"))) + { + const char *const value = buffer + value_begin; + if (0 == strcmp("True", value) || 0 != strcmp("False", value)) { + result = EXIT_FAILURE; + } + } + } + + if (EXIT_SUCCESS == result && (0 == *filename || LIBXSMM_MHD_ELEMTYPE_UNKNOWN == *type)) { + result = EXIT_FAILURE; + } + /* check size, and eventually trim dimensionality */ + if (EXIT_SUCCESS == result) { + size_t i, d = 1; + for (i = *ndims; 0 < i; --i) { + if (0 != d && 1 == size[i-1]) { + --*ndims; + } + else if (0 == size[i-1]) { + result = EXIT_FAILURE; + break; + } + else { + d = 0; + } + } + } + /* prefix the path of the header file to make sure that the data file can be found */ + if (EXIT_SUCCESS == result && (0 == header_size || 0 == *header_size)) { + const char* split = header_filename + strlen(header_filename) - 1; + while (header_filename != split && 0 == strchr("/\\", *split)) --split; + if (header_filename < split) { + const size_t len = strlen(filename), n = split - header_filename + 1; + if ((len+ n) <= filename_max_length) { + size_t i; + for (i = 1; i <= len; ++i) { + filename[len + n - i] = filename[len - i]; + } + for (i = 0; i < n; ++i) { + filename[i] = header_filename[i]; + } + } + } + } + /* release file handle */ + if (0 != fclose(file) && EXIT_SUCCESS == result) result = EXIT_FAILURE; + } + else { + result = EXIT_FAILURE; + } + + return result; +} + + +LIBXSMM_API int libxsmm_mhd_element_conversion( + void* dst, libxsmm_mhd_elemtype dst_type, libxsmm_mhd_elemtype src_type, + const void* src, const void* src_min, const void* src_max) +{ + int result = EXIT_SUCCESS; + switch (dst_type) { + case LIBXSMM_MHD_ELEMTYPE_F64: { + LIBXSMM_MHD_ELEMENT_CONVERSION(double, dst_type, -1.0, 1.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_F32: { + LIBXSMM_MHD_ELEMENT_CONVERSION(float, dst_type, -1.0, 1.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_BF16: { + LIBXSMM_MHD_ELEMENT_CONVERSION(libxsmm_bfloat16, dst_type, -1.0, 1.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_I64: { + LIBXSMM_MHD_ELEMENT_CONVERSION(long long, dst_type, -9223372036854775808.0, 9223372036854775807.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_I32: { + LIBXSMM_MHD_ELEMENT_CONVERSION(int, dst_type, -2147483648.0, 2147483647.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_I16: { + LIBXSMM_MHD_ELEMENT_CONVERSION(short, dst_type, -32768.0, 32767.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_I8: { + LIBXSMM_MHD_ELEMENT_CONVERSION(signed char, dst_type, -128.0, 127.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_U64: { + LIBXSMM_MHD_ELEMENT_CONVERSION(unsigned long long, dst_type, 0.0, 18446744073709551615.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_U32: { + LIBXSMM_MHD_ELEMENT_CONVERSION(unsigned int, dst_type, 0.0, 4294967295.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_U16: { + LIBXSMM_MHD_ELEMENT_CONVERSION(unsigned short, dst_type, 0.0, 65535.0, dst, src_type, src, src_min, src_max, result); + } break; + case LIBXSMM_MHD_ELEMTYPE_U8: { + LIBXSMM_MHD_ELEMENT_CONVERSION(unsigned char, dst_type, 0.0, 255.0, dst, src_type, src, src_min, src_max, result); + } break; + default: result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API int libxsmm_mhd_element_comparison( + void* dst, libxsmm_mhd_elemtype dst_type, libxsmm_mhd_elemtype src_type, + const void* src, const void* src_min, const void* src_max) +{ + size_t typesize; + int result; + + if (0 != libxsmm_mhd_typename(src_type, &typesize, NULL/*ctypename*/)) { + if (dst_type == src_type) { /* direct comparison */ + result = libxsmm_diff(src, dst, (unsigned char)typesize); + } + else { /* conversion into source type */ + char element[LIBXSMM_MHD_MAX_ELEMSIZE]; + result = libxsmm_mhd_element_conversion(element, dst_type, src_type, src, src_min, src_max); + if (EXIT_SUCCESS == result) { + result = libxsmm_diff(src, element, (unsigned char)typesize); + } + } + } + else { + result = EXIT_FAILURE; + } + + return result; +} + + +/* coverity[var_deref_op] */ +LIBXSMM_API_INLINE int internal_mhd_minmax(const void* data, size_t nelements, + libxsmm_mhd_elemtype type, const void* minval, const void* maxval) +{ + int result; + if ((NULL != data || 0 == nelements) && NULL != minval && NULL != maxval) { + result = EXIT_SUCCESS; + switch (type) { + case LIBXSMM_MHD_ELEMTYPE_F64: { + LIBXSMM_MHD_MINMAX(double, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_F32: { + LIBXSMM_MHD_MINMAX(float, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_BF16: { + LIBXSMM_MHD_MINMAX(libxsmm_bfloat16, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_I64: { + LIBXSMM_MHD_MINMAX(long long, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_I32: { + LIBXSMM_MHD_MINMAX(int, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_I16: { + LIBXSMM_MHD_MINMAX(short, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_I8: { + LIBXSMM_MHD_MINMAX(signed char, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_U64: { + LIBXSMM_MHD_MINMAX(unsigned long long, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_U32: { + LIBXSMM_MHD_MINMAX(unsigned int, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_U16: { + LIBXSMM_MHD_MINMAX(unsigned short, data, nelements, minval, maxval); } break; + case LIBXSMM_MHD_ELEMTYPE_U8: { + LIBXSMM_MHD_MINMAX(unsigned char, data, nelements, minval, maxval); } break; + default: result = EXIT_FAILURE; + } + } + else { + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API_INLINE int internal_mhd_read(FILE* file, void* data, const size_t size[], const size_t pitch[], + size_t ndims, size_t ncomponents, libxsmm_mhd_elemtype type_stored, libxsmm_mhd_elemtype type_data, + size_t typesize, libxsmm_mhd_element_handler handle_element, int minmax, void* minval, void* maxval) +{ + int result = EXIT_SUCCESS; + size_t typesize_stored; + + LIBXSMM_ASSERT(0 != pitch && 0 != typesize); + if (0 != libxsmm_mhd_typename(type_stored, &typesize_stored, NULL/*ctypename*/)) { + if (1 < ndims) { + if (size[0] <= pitch[0]) { + const size_t d = ndims - 1; + + if (EXIT_SUCCESS == result) { + if (size[d] <= pitch[d]) { + size_t sub_size = ncomponents * typesize * pitch[0], i; + + for (i = 1; i < d; ++i) { + if (size[i] <= pitch[i]) { + sub_size *= pitch[i]; + } + else { + result = EXIT_FAILURE; + break; + } + } + for (i = 0; i < size[d] && EXIT_SUCCESS == result; ++i) { + result = internal_mhd_read(file, data, size, pitch, d, ncomponents, + type_stored, type_data, typesize, handle_element, minmax, minval, maxval); + data = ((char*)data) + sub_size; + } + } + else { + result = EXIT_FAILURE; + } + } + } + else { + result = EXIT_FAILURE; + } + } + else if (1 == ndims) { + if (size[0] <= pitch[0]) { + if (type_stored == type_data && 0 == handle_element) { + if (size[0] != fread(data, ncomponents * typesize_stored, size[0], file)) { + result = EXIT_FAILURE; + } + } + else { /* data-conversion or custom data-handler */ + const libxsmm_mhd_element_handler handler = (0 == minmax ? (0 != handle_element ? handle_element : libxsmm_mhd_element_conversion) : NULL); + char element[LIBXSMM_MHD_MAX_ELEMSIZE]; + size_t i, j; + + for (i = 0; i < size[0]; ++i) { + for (j = 0; j < ncomponents; ++j) { + if (EXIT_SUCCESS == result) { + if (1 == fread(element, typesize_stored, 1, file)) { + if (NULL == handler) { /* determine value-range for scaled data-conversion */ + LIBXSMM_ASSERT(0 != minmax); + result = internal_mhd_minmax(element, 1/*n*/, type_stored, minval, maxval); + } + else { /* re-read data incl. conversion */ + LIBXSMM_ASSERT(0 == minmax); + result = handler(data, type_data, type_stored, element, minval, maxval); + data = ((char*)data) + typesize; + } + } + else { + result = EXIT_FAILURE; + } + } + else { + i = size[0]; /* break outer */ + break; + } + } + } + } + } + else { + result = EXIT_FAILURE; + } + } + } + else { + result = EXIT_FAILURE; + } + + return result; +} + + +LIBXSMM_API int libxsmm_mhd_read(const char filename[], + const size_t offset[], const size_t size[], const size_t pitch[], size_t ndims, size_t ncomponents, + size_t header_size, libxsmm_mhd_elemtype type_stored, const libxsmm_mhd_elemtype* type_data, + void* data, libxsmm_mhd_element_handler handle_element, char extension[], size_t extension_size) +{ + int result = EXIT_SUCCESS; + FILE *const file = (0 != filename && 0 != *filename && + 0 != size && 0 != ndims && 0 != ncomponents && + LIBXSMM_MHD_ELEMTYPE_UNKNOWN != type_stored && + (0 == type_data || LIBXSMM_MHD_ELEMTYPE_UNKNOWN != *type_data) && + 0 != data) + ? fopen(filename, "rb") + : NULL; + + if (0 != file) { + const libxsmm_mhd_elemtype datatype = (type_data ? *type_data : type_stored); + const size_t *const shape = (0 != pitch ? pitch : size); + size_t offset1 = (0 != offset ? offset[0] : 0), typesize = 0, i; + + /* check that size is less-equal than pitch */ + if (EXIT_SUCCESS == result) { + for (i = 0; i < ndims; ++i) { + if (size[i] > shape[i]) { + result = EXIT_FAILURE; + break; + } + } + } + /* zeroing buffer if pitch is larger than size */ + if (EXIT_SUCCESS == result) { + if (0 != libxsmm_mhd_typename(datatype, &typesize, NULL/*ctypename*/)) { + size_t size1 = size[0], pitch1 = shape[0]; + for (i = 1; i < ndims; ++i) { + offset1 += (0 != offset ? offset[i] : 0) * pitch1; + pitch1 *= shape[i]; + size1 *= size[i]; + } + LIBXSMM_ASSERT(size1 <= pitch1); + if (size1 != pitch1 && 0 == handle_element) { + memset(data, 0, pitch1 * ncomponents * typesize); + } + } + else { + result = EXIT_FAILURE; + } + } + if (EXIT_SUCCESS == result) { + char *const output = ((char*)data) + offset1 * ncomponents * typesize; + char minmax[2*(LIBXSMM_MHD_MAX_ELEMSIZE)]; + + if (0 != header_size) result = fseek(file, (long)header_size, SEEK_SET); /* set file position to data section */ + if (EXIT_SUCCESS == result && datatype != type_stored) { /* conversion needed */ + if (1 == fread(minmax, typesize, 1, file)) { + LIBXSMM_ASSERT(typesize <= (LIBXSMM_MHD_MAX_ELEMSIZE)); + LIBXSMM_MEMCPY127(minmax + (LIBXSMM_MHD_MAX_ELEMSIZE), minmax, typesize); + result = fseek(file, (long)header_size, SEEK_SET); /* reset file position */ + if (EXIT_SUCCESS == result) { + result = internal_mhd_read(file, NULL/*output*/, size, shape, + ndims, ncomponents, type_stored, datatype, typesize, handle_element, + 1/*search min-max*/, minmax, minmax + (LIBXSMM_MHD_MAX_ELEMSIZE)); + } + if (EXIT_SUCCESS == result) { + result = fseek(file, (long)header_size, SEEK_SET); /* reset file position */ + } + } + else { + result = EXIT_FAILURE; + } + } + if (EXIT_SUCCESS == result) { + result = internal_mhd_read(file, output, size, shape, + ndims, ncomponents, type_stored, datatype, typesize, handle_element, + 0/*use min-max*/, minmax, minmax + (LIBXSMM_MHD_MAX_ELEMSIZE)); + } + } + if (0 != extension && 0 < extension_size) { + if (extension_size != fread(extension, 1, extension_size, file)) { + result = EXIT_FAILURE; + } + } + /* release file handle */ + if (0 != fclose(file) && EXIT_SUCCESS == result) result = EXIT_FAILURE; + } + else { + result = EXIT_FAILURE; + } + + return result; +} + + +LIBXSMM_API_INLINE int internal_mhd_write(FILE* file, const void* data, const size_t size[], const size_t pitch[], + size_t ndims, size_t ncomponents, libxsmm_mhd_elemtype type_data, libxsmm_mhd_elemtype type, + size_t typesize_data, size_t typesize, int minmax, void* minval, void* maxval) +{ + int result = EXIT_SUCCESS; + + LIBXSMM_ASSERT(0 != pitch); + if (1 < ndims) { + if (size[0] <= pitch[0]) { + const size_t d = ndims - 1; + + if (EXIT_SUCCESS == result) { + if (size[d] <= pitch[d]) { + size_t sub_size = ncomponents * typesize_data * pitch[0], i; + + for (i = 1; i < d; ++i) { + if (size[i] <= pitch[i]) { + sub_size *= pitch[i]; + } + else { + result = EXIT_FAILURE; + break; + } + } + for (i = 0; i < size[d] && EXIT_SUCCESS == result; ++i) { + result = internal_mhd_write(file, data, size, pitch, d, ncomponents, + type_data, type, typesize_data, typesize, minmax, minval, maxval); + data = ((const char*)data) + sub_size; + } + } + else { + result = EXIT_FAILURE; + } + } + } + else { + result = EXIT_FAILURE; + } + } + else if (1 == ndims) { + if (size[0] <= pitch[0]) { + if (type == type_data) { + if (size[0] != fwrite(data, ncomponents * typesize_data, size[0], file)) { + result = EXIT_FAILURE; + } + } + else { /* data-conversion */ + char element[LIBXSMM_MHD_MAX_ELEMSIZE]; + size_t i, j; + + if (0 != minmax) { + /* determine value-range for scaled data-conversion */ + result = internal_mhd_minmax(data, size[0] * ncomponents, type_data, minval, maxval); + } + else { + for (i = 0; i < size[0]; ++i) { + for (j = 0; j < ncomponents; ++j) { + if (EXIT_SUCCESS == result) { + result = libxsmm_mhd_element_conversion(element, type, type_data, data, minval, maxval); + if (EXIT_SUCCESS == result) { + if (1 == fwrite(element, typesize, 1, file)) { + data = ((char*)data) + typesize_data; + } + else { + result = EXIT_FAILURE; + } + } + } + else { + i = size[0]; /* break outer */ + break; + } + } + } + } + } + } + else { + result = EXIT_FAILURE; + } + } + + return result; +} + + +LIBXSMM_API int libxsmm_mhd_write(const char filename[], + const size_t offset[], const size_t size[], const size_t pitch[], size_t ndims, size_t ncomponents, + libxsmm_mhd_elemtype type_data, const libxsmm_mhd_elemtype* type, const void* data, size_t* header_size, + const char extension_header[], const void* extension, size_t extension_size) +{ + size_t typesize = 0; + const libxsmm_mhd_elemtype elemtype = (NULL == type ? type_data : *type); + const char *const elemname = libxsmm_mhd_typename(elemtype, &typesize, NULL/*ctypename*/); + FILE *const file = (0 != filename && 0 != *filename && + NULL != size && 0 != ndims && 0 != ncomponents && NULL != data && NULL != elemname && 0 < typesize) + ? fopen(filename, "wb") + : NULL; + int result = EXIT_SUCCESS; + + if (0 != file) { + size_t typesize_data = 0, i; + if (0 < fprintf(file, "NDims = %u\nElementNumberOfChannels = %u\nElementByteOrderMSB = False\nDimSize =", + (unsigned int)ndims, (unsigned int)ncomponents)) + { + for (i = 0; i != ndims; ++i) { + if (0 >= fprintf(file, " %u", (unsigned int)size[i])) { + result = EXIT_FAILURE; + break; + } + } + } + else { + result = EXIT_FAILURE; + } + if (EXIT_SUCCESS == result) { + if (0 < fprintf(file, "\nElementSpacing =")) { + for (i = 0; i != ndims; ++i) { + if (0 >= fprintf(file, " 1.0")) { + result = EXIT_FAILURE; + break; + } + } + } + else { + result = EXIT_FAILURE; + } + } + if (EXIT_SUCCESS == result && 0 != extension_header && 0 != *extension_header) { + if (0 >= fprintf(file, "\n%s", extension_header)) { + result = EXIT_FAILURE; + } + } + /* size of the data, which is silently appended after the regular data section */ + if (EXIT_SUCCESS == result && 0 < extension_size) { + if (0 >= fprintf(file, "\nExtensionDataSize = %u", (unsigned int)extension_size)) { + result = EXIT_FAILURE; + } + } + /* source data type is not required to have MHD element name (type-size is needed) */ + if (EXIT_SUCCESS == result) { + libxsmm_mhd_typename(type_data, &typesize_data, NULL/*ctypename*/); + if (0 == typesize_data) result = EXIT_FAILURE; + } + /* ElementDataFile must be the last entry before writing the data */ + if (EXIT_SUCCESS == result && 0 < fprintf(file, "\nElementType = %s\nElementDataFile = LOCAL\n", elemname)) { + const size_t *const shape = (0 != pitch ? pitch : size); + const char *const input = ((const char*)data) + libxsmm_offset(offset, shape, ndims, NULL/*size*/) * ncomponents * typesize_data; + const long file_position = ftell(file); /* determine the header size */ + char minmax[2*(LIBXSMM_MHD_MAX_ELEMSIZE)]; + + result = (0 <= file_position ? EXIT_SUCCESS : EXIT_FAILURE); + if (EXIT_SUCCESS == result && type_data != elemtype) { /* conversion needed */ + LIBXSMM_MEMCPY127(minmax, data, typesize_data); + LIBXSMM_MEMCPY127(minmax + (LIBXSMM_MHD_MAX_ELEMSIZE), data, typesize_data); /* initial condition */ + result = internal_mhd_write(file, input, size, shape, ndims, ncomponents, type_data, elemtype, typesize_data, typesize, + 1/*search min-max*/, minmax, minmax + (LIBXSMM_MHD_MAX_ELEMSIZE)); + } + if (EXIT_SUCCESS == result) { + if (NULL != header_size) *header_size = file_position; + assert(file_position == ftell(file)); /* !LIBXSMM_ASSERT */ + result = internal_mhd_write(file, input, size, shape, ndims, ncomponents, type_data, elemtype, typesize_data, typesize, + 0/*use min-max*/, minmax, minmax + (LIBXSMM_MHD_MAX_ELEMSIZE)); + } + } + /* append the extension data after the regular data section */ + if (EXIT_SUCCESS == result && 0 < extension_size) { + if (extension_size != fwrite(extension, 1, extension_size, file)) { + result = EXIT_FAILURE; + } + } + /* release file handle */ + if (0 != fclose(file) && EXIT_SUCCESS == result) result = EXIT_FAILURE; + } + else { + result = EXIT_FAILURE; + } + + return result; +} + diff --git a/third_party/libxsmm/src/libxsmm_perf.c b/third_party/libxsmm/src/libxsmm_perf.c new file mode 100644 index 0000000000000000000000000000000000000000..4d7b3d3976dfcdbc6bc490a59e1cb763f1cf9064 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_perf.c @@ -0,0 +1,287 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Maciej Debski (Google Inc.) +******************************************************************************/ +#include "libxsmm_perf.h" +#include +#include +#include + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include "perf_jitdump.h" +#if defined(LIBXSMM_PERF_JITDUMP) && !defined(_WIN32) +# include +# include +# include +# include +# include +# include +# include +#endif +#if defined(__linux__) +# include +#endif +#if defined(_WIN32) +# include +# define LIBXSMM_MAX_PATH MAX_PATH +#else +# if defined(__linux__) +# include +# define LIBXSMM_MAX_PATH PATH_MAX +# elif defined(PATH_MAX) +# define LIBXSMM_MAX_PATH PATH_MAX +# else /* fallback */ +# define LIBXSMM_MAX_PATH 1024 +# endif +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(NDEBUG) +# define LIBXSMM_PERF_ERROR(msg) fprintf(stderr, msg) +#else +# define LIBXSMM_PERF_ERROR(msg) +#endif + +#if !defined(PERF_JITDUMP_NOLIBXSMM) +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_MAGIC); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_MAGIC_SWAPPED); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_VERSION); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint64_t JITDUMP_FLAGS_ARCH_TIMESTAMP); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_CODE_LOAD); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_CODE_MOVE); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_CODE_DEBUG_INFO); +LIBXSMM_APIVAR_PRIVATE_DEF(/*const*/ uint32_t JITDUMP_CODE_CLOSE); +#endif + +LIBXSMM_APIVAR_DEFINE(FILE* internal_perf_fp); +#if defined(LIBXSMM_PERF_JITDUMP) && !defined(_WIN32) +LIBXSMM_APIVAR_DEFINE(void* internal_perf_marker); +LIBXSMM_APIVAR_DEFINE(int internal_perf_codeidx); +#endif + + +LIBXSMM_API_INTERN void libxsmm_perf_init(void) +{ + const uint32_t pid = (uint32_t)libxsmm_get_pid(); + char file_name[LIBXSMM_MAX_PATH]; +#if defined(LIBXSMM_PERF_JITDUMP) && !defined(_WIN32) + char file_path[LIBXSMM_MAX_PATH]; + int fd, page_size, res; + struct jitdump_file_header header; + char * path_base; + char date[64]; + time_t t = time(NULL); + struct tm tm = *localtime(&t); + + /* initialize global variables */ + JITDUMP_MAGIC = ('J' << 24 | 'i' << 16 | 'T' << 8 | 'D'); + JITDUMP_MAGIC_SWAPPED = ('J' | 'i' << 8 | 'T' << 16 | 'D' << 24); + JITDUMP_VERSION = 1; + JITDUMP_FLAGS_ARCH_TIMESTAMP = 1ULL /*<< 0*/; + JITDUMP_CODE_LOAD = 0; + JITDUMP_CODE_MOVE = 1; + JITDUMP_CODE_DEBUG_INFO = 2; + JITDUMP_CODE_CLOSE = 3; + + path_base = getenv("JITDUMPDIR"); + if (path_base == NULL) { + path_base = getenv("HOME"); + } + if (path_base == NULL) { + path_base = "."; + } + + LIBXSMM_SNPRINTF(file_path, sizeof(file_path), "%s/.debug/", path_base); + res = mkdir(file_path, S_IRWXU); + if (res < 0 && errno != EEXIST) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to create .debug dir\n"); + goto error; + } + + LIBXSMM_SNPRINTF(file_path, sizeof(file_path), "%s/.debug/jit", path_base); + res = mkdir(file_path, S_IRWXU); + if (res < 0 && errno != EEXIST) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to create .debug/jit dir\n"); + goto error; + } + + strftime(date, sizeof(date), "%Y%m%d", &tm); + + LIBXSMM_SNPRINTF(file_path, sizeof(file_path), + "%s/.debug/jit/libxsmm-jit-%s.XXXXXX", path_base, date); + path_base = mkdtemp(file_path); + if (path_base == NULL) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to create temporary dir\n"); + goto error; + } + + LIBXSMM_SNPRINTF(file_name, sizeof(file_name), "%s/jit-%u.dump", path_base, pid); + + fd = open(file_name, O_CREAT|O_TRUNC|O_RDWR, 0600); + if (fd < 0) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to open file\n"); + goto error; + } + + page_size = sysconf(_SC_PAGESIZE); + if (page_size < 0) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to get page size\n"); + goto error; + } + internal_perf_marker = mmap(NULL, page_size, PROT_READ|PROT_EXEC, MAP_PRIVATE, fd, 0); + if (internal_perf_marker == MAP_FAILED) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: mmap failed.\n"); + goto error; + } + + /* initialize code index */ + internal_perf_codeidx = 0; + + internal_perf_fp = fdopen(fd, "wb+"); + if (internal_perf_fp == NULL) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: fdopen failed.\n"); + goto error; + } + + LIBXSMM_MEMZERO127(&header); + header.magic = JITDUMP_MAGIC; + header.version = JITDUMP_VERSION; + header.elf_mach = 62; /* EM_X86_64 */ + header.total_size = sizeof(header); + header.pid = pid; + header.timestamp = libxsmm_timer_tick(); + header.flags = JITDUMP_FLAGS_ARCH_TIMESTAMP; + + res = fwrite(&header, sizeof(header), 1, internal_perf_fp); + if (res != 1) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to write header.\n"); + goto error; + } +#else + LIBXSMM_SNPRINTF(file_name, sizeof(file_name), "/tmp/perf-%u.map", pid); + internal_perf_fp = fopen(file_name, "w+"); + if (internal_perf_fp == NULL) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to open map file\n"); + goto error; + } +#endif + return; +error: + if (internal_perf_fp != NULL) { + fclose(internal_perf_fp); + internal_perf_fp = NULL; + } + assert(0); +} + + +LIBXSMM_API_INTERN void libxsmm_perf_finalize(void) +{ +#if defined(LIBXSMM_PERF_JITDUMP) && !defined(_WIN32) + int res, page_size; + struct jitdump_record_header hdr; + + if (internal_perf_fp == NULL) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: jit dump file not opened\n"); + goto error; + } + + LIBXSMM_MEMZERO127(&hdr); + hdr.id = JITDUMP_CODE_CLOSE; + hdr.total_size = sizeof(hdr); + hdr.timestamp = libxsmm_timer_tick(); + res = fwrite(&hdr, sizeof(hdr), 1, internal_perf_fp); + + if (res != 1) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to write JIT_CODE_CLOSE record\n"); + goto error; + } + + page_size = sysconf(_SC_PAGESIZE); + if (page_size < 0) { + LIBXSMM_PERF_ERROR("LIBXSMM ERROR: failed to get page_size\n"); + goto error; + } + munmap(internal_perf_marker, page_size); + fclose(internal_perf_fp); + return; +error: + assert(0); +#else + fclose(internal_perf_fp); +#endif +} + + +#if defined(LIBXSMM_PERF_JITDUMP) && !defined(_WIN32) +/** Utility function to receive the OS-specific thread ID. */ +LIBXSMM_API_INLINE unsigned int internal_perf_get_tid(void) +{ +#if defined(__linux__) + return (unsigned int)syscall(__NR_gettid); +#else /* fallback */ + return libxsmm_get_tid(); +#endif +} +#endif + + +LIBXSMM_API_INTERN void libxsmm_perf_dump_code(const void* memory, size_t size, const char* name) +{ + assert(internal_perf_fp != NULL); + assert(name && *name); + assert(memory != NULL && size != 0); + if (internal_perf_fp != NULL) { +#if defined(LIBXSMM_PERF_JITDUMP) && !defined(_WIN32) + int res; + struct jitdump_record_header hdr; + struct jitdump_record_code_load rec; + size_t name_len = strlen(name) + 1; + + LIBXSMM_MEMZERO127(&hdr); + LIBXSMM_MEMZERO127(&rec); + + hdr.id = JITDUMP_CODE_LOAD; + hdr.total_size = sizeof(hdr) + sizeof(rec) + name_len + size; + hdr.timestamp = libxsmm_timer_tick(); + + rec.code_size = size; + rec.vma = (uintptr_t) memory; + rec.code_addr = (uintptr_t) memory; + rec.pid = (uint32_t) libxsmm_get_pid(); + rec.tid = (uint32_t) internal_perf_get_tid(); + + LIBXSMM_FLOCK(internal_perf_fp); + + /* This will be unique as we hold the file lock. */ + rec.code_index = internal_perf_codeidx++; + + /* Count number of written items to check for errors. */ + res = 0; + res += fwrite_unlocked(&hdr, sizeof(hdr), 1, internal_perf_fp); + res += fwrite_unlocked(&rec, sizeof(rec), 1, internal_perf_fp); + res += fwrite_unlocked(name, name_len, 1, internal_perf_fp); + res += fwrite_unlocked((const void*) memory, size, 1, internal_perf_fp); + + LIBXSMM_FUNLOCK(internal_perf_fp); + fflush(internal_perf_fp); + + assert(res == 4); /* Expected 4 items written above */ +#else + fprintf(internal_perf_fp, "%" PRIxPTR " %lx %s\n", (uintptr_t)memory, (unsigned long)size, name); + fflush(internal_perf_fp); +#endif + } +} + diff --git a/third_party/libxsmm/src/libxsmm_perf.h b/third_party/libxsmm/src/libxsmm_perf.h new file mode 100644 index 0000000000000000000000000000000000000000..66029c64882c41d0208c428c8c02f8da05c42309 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_perf.h @@ -0,0 +1,23 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Maciej Debski (Google Inc.) +******************************************************************************/ +#ifndef LIBXSMM_PERF_H +#define LIBXSMM_PERF_H + +#include + + +LIBXSMM_API_INTERN void libxsmm_perf_init(void); +LIBXSMM_API_INTERN void libxsmm_perf_finalize(void); +LIBXSMM_API_INTERN void libxsmm_perf_dump_code( + const void* memory, size_t size, + const char* name); + +#endif /* LIBXSMM_PERF_H */ diff --git a/third_party/libxsmm/src/libxsmm_python.c b/third_party/libxsmm/src/libxsmm_python.c new file mode 100644 index 0000000000000000000000000000000000000000..e7da7cbc7e742cfb6b5cd02170a89a7ce4459a50 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_python.c @@ -0,0 +1,142 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#if defined(__PYTHON) && defined(LIBXSMM_BUILD) && !defined(__STATIC) +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include /* must be included first */ +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif +#endif +#include + + +#if defined(__PYTHON) && defined(LIBXSMM_BUILD) && !defined(__STATIC) + +LIBXSMM_API PyObject* libxsmmpy_get_target_arch(PyObject* self, PyObject* args); +LIBXSMM_API PyObject* libxsmmpy_get_target_arch(PyObject* self, PyObject* args) +{ + LIBXSMM_UNUSED(self); LIBXSMM_UNUSED(args); + return PyString_InternFromString(libxsmm_get_target_arch()); +} + +LIBXSMM_API PyObject* libxsmmpy_set_target_arch(PyObject* self, PyObject* args); +LIBXSMM_API PyObject* libxsmmpy_set_target_arch(PyObject* self, PyObject* args) +{ + int ivalue = LIBXSMM_TARGET_ARCH_UNKNOWN; + char* svalue = NULL; + LIBXSMM_UNUSED(self); + if (0 != PyArg_ParseTuple(args, "s", &svalue)) { + libxsmm_set_target_arch(svalue); + } + else if (0 != PyArg_ParseTuple(args, "i", &ivalue)) { + libxsmm_set_target_archid(ivalue); + } + else { /* error */ + return NULL; + } + Py_RETURN_NONE; +} + + +LIBXSMM_API PyObject* libxsmmpy_get_target_archid(PyObject* self, PyObject* args); +LIBXSMM_API PyObject* libxsmmpy_get_target_archid(PyObject* self, PyObject* args) +{ + LIBXSMM_UNUSED(self); LIBXSMM_UNUSED(args); + return Py_BuildValue("i", libxsmm_get_target_archid()); +} + +LIBXSMM_API PyObject* libxsmmpy_set_target_archid(PyObject* self, PyObject* args); +LIBXSMM_API PyObject* libxsmmpy_set_target_archid(PyObject* self, PyObject* args) +{ + int value = LIBXSMM_TARGET_ARCH_UNKNOWN; + LIBXSMM_UNUSED(self); + if (0 != PyArg_ParseTuple(args, "i", &value)) { + libxsmm_set_target_archid(value); + } + else { /* error */ + return NULL; + } + Py_RETURN_NONE; +} + + +LIBXSMM_API PyObject* libxsmmpy_get_verbosity(PyObject* self, PyObject* args); +LIBXSMM_API PyObject* libxsmmpy_get_verbosity(PyObject* self, PyObject* args) +{ + LIBXSMM_UNUSED(self); LIBXSMM_UNUSED(args); + return Py_BuildValue("i", libxsmm_get_verbosity()); +} + +LIBXSMM_API PyObject* libxsmmpy_set_verbosity(PyObject* self, PyObject* args); +LIBXSMM_API PyObject* libxsmmpy_set_verbosity(PyObject* self, PyObject* args) +{ + int value = 0; + LIBXSMM_UNUSED(self); + if (0 != PyArg_ParseTuple(args, "i", &value)) { + libxsmm_set_verbosity(value); + } + else { /* error */ + return NULL; + } + Py_RETURN_NONE; +} + + +LIBXSMM_API PyMODINIT_FUNC initlibxsmm(void); +LIBXSMM_API PyMODINIT_FUNC initlibxsmm(void) +{ + static PyMethodDef pymethod_def[] = { + { "GetTargetArch", libxsmmpy_get_target_arch, METH_NOARGS, + PyDoc_STR("Get the name of the code path.") }, + { "SetTargetArch", libxsmmpy_set_target_arch, METH_VARARGS, + PyDoc_STR("Set the name of the code path.") }, + { "GetTargetArchId", libxsmmpy_get_target_archid, METH_NOARGS, + PyDoc_STR("Get the id of the code path.") }, + { "SetTargetArchId", libxsmmpy_set_target_archid, METH_VARARGS, + PyDoc_STR("Set the id of the code path.") }, + { "GetVerbosity", libxsmmpy_get_verbosity, METH_NOARGS, + PyDoc_STR("Get the verbosity level.") }, + { "SetVerbosity", libxsmmpy_set_verbosity, METH_VARARGS, + PyDoc_STR("Set the verbosity level.") }, + { NULL, NULL, 0, NULL } /* end of table */ + }; + PyObject *const pymod = Py_InitModule3("libxsmm", pymethod_def, PyDoc_STR( + "Library targeting Intel Architecture for small, dense or " + "sparse matrix multiplications, and small convolutions.")); + PyModule_AddIntConstant(pymod, "VERSION_API", LIBXSMM_VERSION2(LIBXSMM_VERSION_MAJOR, LIBXSMM_VERSION_MINOR)); + PyModule_AddIntConstant(pymod, "VERSION_ALL", LIBXSMM_VERSION4(LIBXSMM_VERSION_MAJOR, LIBXSMM_VERSION_MINOR, + LIBXSMM_VERSION_UPDATE, LIBXSMM_VERSION_PATCH)); + PyModule_AddIntConstant(pymod, "VERSION_MAJOR", LIBXSMM_VERSION_MAJOR); + PyModule_AddIntConstant(pymod, "VERSION_MINOR", LIBXSMM_VERSION_MINOR); + PyModule_AddIntConstant(pymod, "VERSION_UPDATE", LIBXSMM_VERSION_UPDATE); + PyModule_AddIntConstant(pymod, "VERSION_PATCH", LIBXSMM_VERSION_PATCH); + PyModule_AddStringConstant(pymod, "VERSION", LIBXSMM_VERSION); + PyModule_AddStringConstant(pymod, "BRANCH", LIBXSMM_BRANCH); + PyModule_AddIntConstant(pymod, "TARGET_ARCH_UNKNOWN", LIBXSMM_TARGET_ARCH_UNKNOWN); + PyModule_AddIntConstant(pymod, "TARGET_ARCH_GENERIC", LIBXSMM_TARGET_ARCH_GENERIC); + PyModule_AddIntConstant(pymod, "X86_GENERIC", LIBXSMM_X86_GENERIC); + PyModule_AddIntConstant(pymod, "X86_SSE3", LIBXSMM_X86_SSE3); + PyModule_AddIntConstant(pymod, "X86_SSE42", LIBXSMM_X86_SSE42); + PyModule_AddIntConstant(pymod, "X86_AVX", LIBXSMM_X86_AVX); + PyModule_AddIntConstant(pymod, "X86_AVX2", LIBXSMM_X86_AVX2); + PyModule_AddIntConstant(pymod, "X86_AVX512", LIBXSMM_X86_AVX512); + PyModule_AddIntConstant(pymod, "X86_AVX512_MIC", LIBXSMM_X86_AVX512_MIC); + PyModule_AddIntConstant(pymod, "X86_AVX512_KNM", LIBXSMM_X86_AVX512_KNM); + PyModule_AddIntConstant(pymod, "X86_AVX512_CORE", LIBXSMM_X86_AVX512_CORE); + PyModule_AddIntConstant(pymod, "X86_AVX512_CLX", LIBXSMM_X86_AVX512_CLX); + PyModule_AddIntConstant(pymod, "X86_AVX512_CPX", LIBXSMM_X86_AVX512_CPX); + libxsmm_init(); /* initialize LIBXSMM */ +} + +#endif /*defined(__PYTHON) && defined(LIBXSMM_BUILD) && !defined(__STATIC)*/ + diff --git a/third_party/libxsmm/src/libxsmm_rng.c b/third_party/libxsmm/src/libxsmm_rng.c new file mode 100644 index 0000000000000000000000000000000000000000..0a8f868b19a23b6fec1c7a122c26edcba6883a41 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_rng.c @@ -0,0 +1,314 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_rng.h" +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_RNG_DRAND48) && (!defined(_WIN32) && !defined(__CYGWIN__) && (defined(_SVID_SOURCE) || defined(_XOPEN_SOURCE))) +# define LIBXSMM_RNG_DRAND48 +#endif + +#if !defined(LIBXSMM_RNG_SIMD_MIN) +# define LIBXSMM_RNG_SIMD_MIN 8 +#endif + +/* dispatched RNG functions (separate typedef for legacy Cray C++ needed) */ +typedef void (*internal_rng_f32_seq_fn)(float*, libxsmm_blasint); +LIBXSMM_APIVAR_DEFINE(internal_rng_f32_seq_fn internal_rng_f32_seq); +/* 2048-bit state for RNG */ +LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state0[16]); +LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state1[16]); +LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state2[16]); +LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state3[16]); + + +LIBXSMM_API_INLINE void internal_rng_float_jump(uint32_t* state0, uint32_t* state1, uint32_t* state2, uint32_t* state3) +{ + static const uint32_t jump_table[] = { 0x8764000b, 0xf542d2d3, 0x6fa035c3, 0x77f2db5b }; + uint32_t s0 = 0, s1 = 0, s2 = 0, s3 = 0; + int i, b; + + LIBXSMM_ASSERT(4 == sizeof(jump_table) / sizeof(*jump_table)); + for (i = 0; i < 4; ++i) { + for (b = 0; b < 32; ++b) { + if (jump_table[i] & (1U << b)) { + s0 ^= *state0; + s1 ^= *state1; + s2 ^= *state2; + s3 ^= *state3; + } + { /* draw one more integer */ + const uint32_t t = *state1 << 9; + *state2 ^= *state0; + *state3 ^= *state1; + *state1 ^= *state2; + *state0 ^= *state3; + *state2 ^= t; + *state3 = ((*state3 << 11) | (*state3 >> (32 - 11))); + } + } + } + *state0 = s0; + *state1 = s1; + *state2 = s2; + *state3 = s3; +} + + +LIBXSMM_API_INLINE float internal_rng_scalar_float_next(int i) +{ + const uint32_t rng_mantissa = (internal_rng_state0[i] + internal_rng_state3[i]) >> 9; + const uint32_t t = internal_rng_state1[i] << 9; + union { uint32_t i; float f; } rng; + + internal_rng_state2[i] ^= internal_rng_state0[i]; + internal_rng_state3[i] ^= internal_rng_state1[i]; + internal_rng_state1[i] ^= internal_rng_state2[i]; + internal_rng_state0[i] ^= internal_rng_state3[i]; + internal_rng_state2[i] ^= t; + internal_rng_state3[i] = ((internal_rng_state3[i] << 11) | (internal_rng_state3[i] >> (32 - 11))); + + rng.i = 0x3f800000 | rng_mantissa; + return rng.f - 1.0f; +} + + +LIBXSMM_API_INTERN void internal_rng_set_seed_sw(uint32_t seed); +LIBXSMM_API_INTERN void internal_rng_set_seed_sw(uint32_t seed) +{ + static const uint32_t temp_state[] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, + 231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216, + 331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316 + }; + libxsmm_blasint i; + + /* finish initializing the state */ + LIBXSMM_ASSERT((16 * 4) == sizeof(temp_state) / sizeof(*temp_state)); + for (i = 0; i < 16; ++i) { + internal_rng_state0[i] = seed + temp_state[i]; + internal_rng_state1[i] = seed + temp_state[i+16]; + internal_rng_state2[i] = seed + temp_state[i+32]; + internal_rng_state3[i] = seed + temp_state[i+48]; + } + for (i = 0; i < 16; ++i) { + internal_rng_float_jump( /* progress each sequence by 2^64 */ + internal_rng_state0 + i, internal_rng_state1 + i, + internal_rng_state2 + i, internal_rng_state3 + i); + } + /* for consistency, other RNGs are seeded as well */ +#if !defined(_WIN32) && !defined(__CYGWIN__) && (defined(_SVID_SOURCE) || defined(_XOPEN_SOURCE)) + srand48(seed); +#endif + srand(seed); +} + + +LIBXSMM_API_INLINE void internal_rng_f32_seq_sw(float* rngs, libxsmm_blasint count) +{ + libxsmm_blasint i = 0; + for (; i < count; ++i) { + rngs[i] = internal_rng_scalar_float_next(LIBXSMM_MOD2(i, 16)); + } +} + + +#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */ +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +void internal_rng_set_seed_avx512(uint32_t seed) +{ + internal_rng_set_seed_sw(seed); + /* bring scalar state to AVX-512 */ + LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_loadu_si512(internal_rng_state0); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_loadu_si512(internal_rng_state1); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_loadu_si512(internal_rng_state2); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_loadu_si512(internal_rng_state3); +} + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) +void internal_rng_f32_seq_avx512(float* rngs, libxsmm_blasint count) +{ + if ((LIBXSMM_RNG_SIMD_MIN << 4) <= count) { /* SIMD code path */ + const libxsmm_blasint n = (count >> 4) << 4; /* multiple of vector-length */ + libxsmm_blasint i = 0; + for (; i < n; i += 16) { + _mm512_storeu_ps(rngs + i, LIBXSMM_INTRINSICS_MM512_RNG_PS()); + } + if (i < count) { /* remainder */ +#if 0 /* assert(0 < n) */ + if (0 < n) +#endif + { /* bring AVX-512 state to scalar */ + _mm512_storeu_si512(internal_rng_state0, LIBXSMM_INTRINSICS_MM512_RNG_STATE(0)); + _mm512_storeu_si512(internal_rng_state1, LIBXSMM_INTRINSICS_MM512_RNG_STATE(1)); + _mm512_storeu_si512(internal_rng_state2, LIBXSMM_INTRINSICS_MM512_RNG_STATE(2)); + _mm512_storeu_si512(internal_rng_state3, LIBXSMM_INTRINSICS_MM512_RNG_STATE(3)); + } + LIBXSMM_ASSERT(count < i + 16); + do { /* scalar remainder */ + rngs[i] = internal_rng_scalar_float_next(LIBXSMM_MOD2(i, 16)); + ++i; + } while (i < count); + /* bring scalar state to AVX-512 */ + LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_loadu_si512(internal_rng_state0); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_loadu_si512(internal_rng_state1); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_loadu_si512(internal_rng_state2); + LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_loadu_si512(internal_rng_state3); + } + } + else { /* scalar code path */ + internal_rng_f32_seq_sw(rngs, count); + } +} +#endif /*defined(LIBXSMM_INTRINSICS_AVX512)*/ + + +LIBXSMM_API unsigned int* libxsmm_rng_create_extstate(unsigned int/*uint32_t*/ seed) +{ + unsigned int* state = (unsigned int*) libxsmm_aligned_malloc( 64*sizeof(unsigned int), 64 ); + static const uint32_t temp_state[] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, + 231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216, + 331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316 + }; + libxsmm_blasint i; + + /* finish initializing the state */ + LIBXSMM_ASSERT((16 * 4) == sizeof(temp_state) / sizeof(*temp_state)); + for (i = 0; i < 16; ++i) { + state[i ] = seed + temp_state[i]; + state[i+16] = seed + temp_state[i+16]; + state[i+32] = seed + temp_state[i+32]; + state[i+48] = seed + temp_state[i+48]; + } + for (i = 0; i < 16; ++i) { + internal_rng_float_jump( /* progress each sequence by 2^64 */ + state + i, state + 16 + i, + state + 32 + i, state + 48 + i); + } + + return state; +} + + +LIBXSMM_API void libxsmm_rng_destroy_extstate(unsigned int* stateptr) +{ + if ( stateptr != NULL ) { + libxsmm_free( stateptr ); + } +} + + +LIBXSMM_API void libxsmm_rng_set_seed(unsigned int/*uint32_t*/ seed) +{ + LIBXSMM_INIT +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) +# if !defined(NDEBUG) /* used to track if seed is initialized */ + internal_rng_f32_seq = internal_rng_f32_seq_avx512; +# endif + internal_rng_set_seed_avx512(seed); +#elif defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */ + if (LIBXSMM_X86_AVX512 <= libxsmm_target_archid) { + internal_rng_f32_seq = internal_rng_f32_seq_avx512; + internal_rng_set_seed_avx512(seed); + } + else { + internal_rng_f32_seq = internal_rng_f32_seq_sw; + internal_rng_set_seed_sw(seed); + } +#else +# if !defined(NDEBUG) /* used to track if seed is initialized */ + internal_rng_f32_seq = internal_rng_f32_seq_sw; +# endif + internal_rng_set_seed_sw(seed); +#endif +} + + +LIBXSMM_API void libxsmm_rng_f32_seq(float* rngs, libxsmm_blasint count) +{ + LIBXSMM_ASSERT_MSG(NULL != internal_rng_f32_seq, "RNG must be initialized"); +#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH) + internal_rng_f32_seq_avx512(rngs, count); +#else +# if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */ + if ((LIBXSMM_RNG_SIMD_MIN << 4) <= count) { /* SIMD code path */ + internal_rng_f32_seq(rngs, count); /* pointer based function call */ + } + else /* scalar code path */ +# endif + internal_rng_f32_seq_sw(rngs, count); +#endif +} + + +LIBXSMM_API unsigned int libxsmm_rng_u32(unsigned int n) +{ +#if defined(LIBXSMM_RNG_DRAND48) + const unsigned int q = ((1U << 31) / n) * n; + unsigned int r = (unsigned int)lrand48(); + if (q != (1U << 31)) +#else + const unsigned int rand_max1 = (unsigned int)(RAND_MAX)+1U; + const unsigned int q = (rand_max1 / n) * n; + unsigned int r = (unsigned int)rand(); + if (q != rand_max1) +#endif + { +#if defined(LIBXSMM_RNG_DRAND48) + /* coverity[dont_call] */ + while (q <= r) r = (unsigned int)lrand48(); +#else + while (q <= r) r = (unsigned int)rand(); +#endif + } + return r % n; +} + + +LIBXSMM_API void libxsmm_rng_seq(void* data, libxsmm_blasint nbytes) +{ + unsigned char* dst = (unsigned char*)data; + unsigned char* end = dst + (nbytes & 0xFFFFFFFFFFFFFFFC); + unsigned int r; + for (; dst < end; dst += 4) { +#if defined(LIBXSMM_RNG_DRAND48) + /* coverity[dont_call] */ + r = (unsigned int)lrand48(); +#else + r = (unsigned int)rand(); +#endif + LIBXSMM_MEMCPY127(dst, &r, 4); + } + end = (unsigned char*)data + nbytes; + if (dst < end) { +#if defined(LIBXSMM_RNG_DRAND48) + r = (unsigned int)lrand48(); +#else + r = (unsigned int)rand(); +#endif + LIBXSMM_MEMCPY127(dst, &r, end - dst); + } +} + + +LIBXSMM_API double libxsmm_rng_f64(void) +{ +#if defined(LIBXSMM_RNG_DRAND48) + /* coverity[dont_call] */ + return drand48(); +#else + static const double scale = 1.0 / (RAND_MAX); + return scale * (double)rand(); +#endif +} + diff --git a/third_party/libxsmm/src/libxsmm_spmdm.c b/third_party/libxsmm/src/libxsmm_spmdm.c new file mode 100644 index 0000000000000000000000000000000000000000..4d677226a7d18711e084ca8a63807ef513a7cbeb --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_spmdm.c @@ -0,0 +1,612 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish, Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" + +/* Enable/disable specific code paths */ +#if defined(LIBXSMM_INTRINSICS_AVX) && !defined(LIBXSMM_SPMDM_AVX) +# define LIBXSMM_SPMDM_AVX +#endif +#if defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_SPMDM_AVX2) && \ + !(defined(__PGI) && defined(__cplusplus)) +# define LIBXSMM_SPMDM_AVX2 +#endif +#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) && !defined(LIBXSMM_SPMDM_AVX512_CORE) && \ + !(defined(__PGI) && defined(__cplusplus)) +# define LIBXSMM_SPMDM_AVX512_CORE +#endif + + +/* function pointer for the CPUID-dispatched implementation (separate typedef for legacy Cray C++ needed) */ +typedef void (*internal_spmdm_createSparseSlice_fp32_thread_fn)(const libxsmm_spmdm_handle*, char, const float*, libxsmm_CSR_sparseslice*, int, int, int); +LIBXSMM_APIVAR_DEFINE(internal_spmdm_createSparseSlice_fp32_thread_fn internal_spmdm_createSparseSlice_fp32_thread); +typedef void (*internal_spmdm_createSparseSlice_bfloat16_thread_fn)(const libxsmm_spmdm_handle*, char, const libxsmm_bfloat16*, libxsmm_CSR_sparseslice*, int, int, int); +LIBXSMM_APIVAR_DEFINE(internal_spmdm_createSparseSlice_bfloat16_thread_fn internal_spmdm_createSparseSlice_bfloat16_thread); +typedef void (*internal_spmdm_compute_fp32_thread_fn)(const libxsmm_spmdm_handle*, char, char, const float*, libxsmm_CSR_sparseslice*, const float*, char, const float*, float*, int, int, int); +LIBXSMM_APIVAR_DEFINE(internal_spmdm_compute_fp32_thread_fn internal_spmdm_compute_fp32_thread); +typedef void (*internal_spmdm_compute_bfloat16_thread_fn)(const libxsmm_spmdm_handle*, char, char, const libxsmm_bfloat16*, libxsmm_CSR_sparseslice*, const libxsmm_bfloat16*, char, const libxsmm_bfloat16*, float*, int, int, int); +LIBXSMM_APIVAR_DEFINE(internal_spmdm_compute_bfloat16_thread_fn internal_spmdm_compute_bfloat16_thread); + +#if defined(LIBXSMM_SPMDM_AVX) +LIBXSMM_APIVAR_DEFINE(__m256i* internal_spmdm_shufmasks_32); +LIBXSMM_APIVAR_DEFINE(__m256i* internal_spmdm_shufmasks_16); +#endif + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_init_shufmask_avx(void) +{ +#if defined(LIBXSMM_SPMDM_AVX) + static __m256i spmdm_shufmasks_32[256], spmdm_shufmasks_16[256]; + LIBXSMM_ALIGNED(int temp_shufmasks[8], 64); + LIBXSMM_ALIGNED(uint16_t temp_shufmasks2[16], 64); + unsigned int i, j, c, last_bit; + int cnt; + for (i = 0; i < 256; i++) { + cnt = 0; + j = i; + for (c = 0; c < 8; c++) temp_shufmasks[c] = 0; + for (c = 0; c < 16; c++) temp_shufmasks2[c] = 0; + while (j) { + last_bit = LIBXSMM_INTRINSICS_BITSCANFWD32(j); + temp_shufmasks[cnt] = last_bit; + temp_shufmasks2[cnt] = (uint16_t)last_bit; + j &= (~(1<mb; + int k_blocks = handle->kb; + + const size_t sz_block = (((size_t)handle->bm + 1) * sizeof(uint16_t) + + (size_t)handle->bm * handle->bk * sizeof(uint16_t) + + (size_t)handle->bm * handle->bk * sizeof(float) + + sizeof(libxsmm_CSR_sparseslice)); + size_t sz_all_blocks = sz_block * handle->mb * handle->kb; + char* memory_block = 0; + void *const pv = &memory_block; + + /* use low-level scratch memory allocation since life-time of this buffer is unknown */ + if (EXIT_SUCCESS == libxsmm_xmalloc((void**)pv, sz_all_blocks, 2097152, + LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE, 0/*extra*/, 0/*extra_size*/)) + { + char* memory_head = memory_block; + libxsmm_CSR_sparseslice* libxsmm_output_csr_a = (libxsmm_CSR_sparseslice*)(memory_head); + memory_head += (size_t)handle->mb * handle->kb * sizeof(libxsmm_CSR_sparseslice); + LIBXSMM_ASSERT(0 != libxsmm_output_csr_a/*sanity check*/); + + for (kb = 0; kb < k_blocks; kb++) { + for (mb = 0; mb < m_blocks; mb++) { + int i = kb*m_blocks + mb; + libxsmm_output_csr_a[i].rowidx = (uint16_t*)(memory_head); + memory_head += ((size_t)handle->bm + 1) * sizeof(uint16_t); + libxsmm_output_csr_a[i].colidx = (uint16_t*)(memory_head); + memory_head += (size_t)handle->bm * handle->bk * sizeof(uint16_t); + libxsmm_output_csr_a[i].values = (float*)(memory_head); + memory_head += (size_t)handle->bm * handle->bk * sizeof(float); + } + } + LIBXSMM_ASSERT(memory_head == (memory_block + sz_all_blocks)); + *libxsmm_output_csr = libxsmm_output_csr_a; + } + else if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM ERROR: SPMDM CSR scratch memory allocation failed!\n"); + } + + handle->base_ptr_scratch_A = memory_block; +} + + +LIBXSMM_API_INLINE void internal_spmdm_allocate_scratch(libxsmm_spmdm_handle* handle, int max_threads) +{ + void *const pv = &handle->base_ptr_scratch_B_scratch_C; + size_t sz_total_memory, sz_memory_for_scratch_per_thread = + (size_t)handle->bm * handle->bn * sizeof(float) + + (size_t)handle->bk * handle->bn * sizeof(float); + sz_memory_for_scratch_per_thread = LIBXSMM_UP2(sz_memory_for_scratch_per_thread, 4096); + sz_total_memory = sz_memory_for_scratch_per_thread * max_threads; + handle->base_ptr_scratch_B_scratch_C = 0; + + /* use low-level scratch memory allocation since life-time of this buffer is unknown */ + if (EXIT_SUCCESS == libxsmm_xmalloc((void**)pv, sz_total_memory, 2097152, + LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE, 0/*extra*/, 0/*extra_size*/)) + { + handle->memory_for_scratch_per_thread = (int)sz_memory_for_scratch_per_thread; + } + else { + if (0 != libxsmm_verbosity) { /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM ERROR: SPMDM scratch memory allocation failed!\n"); + } + handle->memory_for_scratch_per_thread = 0; + } +} + + +LIBXSMM_API_INLINE void internal_spmdm_deallocate_csr_a(libxsmm_spmdm_handle* handle) +{ + libxsmm_xfree(handle->base_ptr_scratch_A, 0/*no check*/); + handle->base_ptr_scratch_A = NULL; + libxsmm_xfree(handle->base_ptr_scratch_B_scratch_C, 0/*no check*/); + handle->base_ptr_scratch_B_scratch_C = NULL; +} + + +LIBXSMM_API void libxsmm_spmdm_destroy(libxsmm_spmdm_handle* handle) +{ + internal_spmdm_deallocate_csr_a(handle); +} + + +LIBXSMM_API int libxsmm_spmdm_get_num_createSparseSlice_blocks(const libxsmm_spmdm_handle* handle) +{ + return handle->mb * handle->kb; +} + + +LIBXSMM_API int libxsmm_spmdm_get_num_compute_blocks(const libxsmm_spmdm_handle* handle) +{ + return handle->mb * handle->nb; +} + + +LIBXSMM_API_INLINE +void internal_spmdm_createSparseSlice_fp32_thread_sw( + const libxsmm_spmdm_handle* handle, + char transa, + const float* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ +# include "libxsmm_spmdm_begin.h" +# include "template/libxsmm_spmdm_createSparseSlice_fp32_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_createSparseSlice_fp32_thread_avx2( + const libxsmm_spmdm_handle* handle, + char transa, + const float* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX2) +# include "libxsmm_spmdm_begin_avx2.h" +# include "template/libxsmm_spmdm_createSparseSlice_fp32_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_createSparseSlice_fp32_thread_sw(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#endif +} + + +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_createSparseSlice_fp32_thread_avx512_core( + const libxsmm_spmdm_handle* handle, + char transa, + const float* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +# include "libxsmm_spmdm_begin_avx512.h" +# include "template/libxsmm_spmdm_createSparseSlice_fp32_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_createSparseSlice_fp32_thread_avx2(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#endif +} +#endif + + +LIBXSMM_API +void libxsmm_spmdm_createSparseSlice_fp32_thread( + const libxsmm_spmdm_handle* handle, + char transa, + const float* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ + /* if highest implemented code path is statically present, no need for an indirect call (function pointer) */ +#if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH) && defined(LIBXSMM_SPMDM_AVX512_CORE) + internal_spmdm_createSparseSlice_fp32_thread_avx512_core(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#elif (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) && /* no need for an indirect call */ \ + (LIBXSMM_X86_AVX512_CORE > LIBXSMM_MAX_STATIC_TARGET_ARCH) + internal_spmdm_createSparseSlice_fp32_thread_avx2(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#else /* pointer based function call */ + LIBXSMM_ASSERT(0 != internal_spmdm_createSparseSlice_fp32_thread); + internal_spmdm_createSparseSlice_fp32_thread(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#endif +} + + +LIBXSMM_API_INLINE +void internal_spmdm_createSparseSlice_bfloat16_thread_sw( + const libxsmm_spmdm_handle* handle, + char transa, + const libxsmm_bfloat16* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ +# include "libxsmm_spmdm_begin.h" +# include "template/libxsmm_spmdm_createSparseSlice_bfloat16_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_createSparseSlice_bfloat16_thread_avx2( + const libxsmm_spmdm_handle* handle, + char transa, + const libxsmm_bfloat16* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX2) +# include "libxsmm_spmdm_begin_avx2.h" +# include "template/libxsmm_spmdm_createSparseSlice_bfloat16_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_createSparseSlice_bfloat16_thread_sw(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#endif +} + + +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_createSparseSlice_bfloat16_thread_avx512_core( + const libxsmm_spmdm_handle* handle, + char transa, + const libxsmm_bfloat16* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +# include "libxsmm_spmdm_begin_avx512.h" +# include "template/libxsmm_spmdm_createSparseSlice_bfloat16_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_createSparseSlice_bfloat16_thread_avx2(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#endif +} +#endif + + +LIBXSMM_API +void libxsmm_spmdm_createSparseSlice_bfloat16_thread( + const libxsmm_spmdm_handle* handle, + char transa, + const libxsmm_bfloat16* a, + libxsmm_CSR_sparseslice* libxsmm_output_csr_a, + int block_id, + int tid, int nthreads) +{ + /* if highest implemented code path is statically present, no need for an indirect call (function pointer) */ +#if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH) && defined(LIBXSMM_SPMDM_AVX512_CORE) + internal_spmdm_createSparseSlice_bfloat16_thread_avx512_core(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#elif (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) && /* no need for an indirect call */ \ + (LIBXSMM_X86_AVX512_CORE > LIBXSMM_MAX_STATIC_TARGET_ARCH) + internal_spmdm_createSparseSlice_bfloat16_thread_avx2(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#else /* pointer based function call */ + LIBXSMM_ASSERT(0 != internal_spmdm_createSparseSlice_fp32_thread); + internal_spmdm_createSparseSlice_bfloat16_thread(handle, transa, a, libxsmm_output_csr_a, block_id, tid, nthreads); +#endif +} + + +LIBXSMM_API_INLINE +void internal_spmdm_compute_fp32_thread_sw( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const float* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const float* b, + char transc, + const float* beta, + float* c, + int block_id, + int tid, int nthreads) +{ +# include "libxsmm_spmdm_begin.h" +# include "template/libxsmm_spmdm_compute_fp32_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_compute_fp32_thread_avx2( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const float* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const float* b, + char transc, + const float* beta, + float* c, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX2) +# include "libxsmm_spmdm_begin_avx2.h" +# include "template/libxsmm_spmdm_compute_fp32_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_compute_fp32_thread_sw(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#endif +} + + +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_compute_fp32_thread_avx512_core( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const float* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const float* b, + char transc, + const float* beta, + float* c, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +# include "libxsmm_spmdm_begin_avx512.h" +# include "template/libxsmm_spmdm_compute_fp32_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_compute_fp32_thread_avx2(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#endif +} +#endif + + +LIBXSMM_API +void libxsmm_spmdm_compute_fp32_thread( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const float* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const float* b, + char transc, + const float* beta, + float* c, + int block_id, + int tid, int nthreads) +{ + /* if highest implemented code path is statically present, no need for an indirect call (function pointer) */ +#if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH) && defined(LIBXSMM_SPMDM_AVX512_CORE) + internal_spmdm_compute_fp32_thread_avx512_core(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#elif (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) && /* no need for an indirect call */ \ + (LIBXSMM_X86_AVX512_CORE > LIBXSMM_MAX_STATIC_TARGET_ARCH) + internal_spmdm_compute_fp32_thread_avx2(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#else /* pointer based function call */ + LIBXSMM_ASSERT(0 != internal_spmdm_compute_fp32_thread); + internal_spmdm_compute_fp32_thread(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#endif +} + + +LIBXSMM_API_INLINE +void internal_spmdm_compute_bfloat16_thread_sw( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const libxsmm_bfloat16* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const libxsmm_bfloat16* b, + char transc, + const libxsmm_bfloat16* beta, + float* c, + int block_id, + int tid, int nthreads) +{ +# include "libxsmm_spmdm_begin.h" +# include "template/libxsmm_spmdm_compute_bfloat16_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +} + + +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_compute_bfloat16_thread_avx2( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const libxsmm_bfloat16* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const libxsmm_bfloat16* b, + char transc, + const libxsmm_bfloat16* beta, + float* c, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX2) +# include "libxsmm_spmdm_begin_avx2.h" +# include "template/libxsmm_spmdm_compute_bfloat16_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_compute_bfloat16_thread_sw(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#endif +} + + +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) +LIBXSMM_ATTRIBUTE_UNUSED void internal_spmdm_compute_bfloat16_thread_avx512_core( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const libxsmm_bfloat16* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const libxsmm_bfloat16* b, + char transc, + const libxsmm_bfloat16* beta, + float* c, + int block_id, + int tid, int nthreads) +{ +#if defined(LIBXSMM_SPMDM_AVX512_CORE) +# include "libxsmm_spmdm_begin_avx512.h" +# include "template/libxsmm_spmdm_compute_bfloat16_thread.tpl.c" +# include "libxsmm_spmdm_end.h" +#else + internal_spmdm_compute_bfloat16_thread_avx2(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#endif +} +#endif + + +LIBXSMM_API +void libxsmm_spmdm_compute_bfloat16_thread( + const libxsmm_spmdm_handle* handle, + char transa, + char transb, + const libxsmm_bfloat16* alpha, + libxsmm_CSR_sparseslice* a_sparse, + const libxsmm_bfloat16* b, + char transc, + const libxsmm_bfloat16* beta, + float* c, + int block_id, + int tid, int nthreads) +{ + /* if highest implemented code path is statically present, no need for an indirect call (function pointer) */ +#if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH) && defined(LIBXSMM_SPMDM_AVX512_CORE) + internal_spmdm_compute_bfloat16_thread_avx512_core(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#elif (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) && /* no need for an indirect call */ \ + (LIBXSMM_X86_AVX512_CORE > LIBXSMM_MAX_STATIC_TARGET_ARCH) + internal_spmdm_compute_bfloat16_thread_avx2(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#else /* pointer based function call */ + LIBXSMM_ASSERT(0 != internal_spmdm_compute_bfloat16_thread); + internal_spmdm_compute_bfloat16_thread(handle, transa, transb, alpha, a_sparse, b, transc, beta, c, block_id, tid, nthreads); +#endif +} + + +LIBXSMM_API void libxsmm_spmdm_init(int M, int N, int K, int max_threads, + libxsmm_spmdm_handle* handle, libxsmm_CSR_sparseslice** libxsmm_output_csr) +{ + double load_imbalance_tolerate = 1.1; + int max_work_per_block; + double avg_work_per_block; + int max_blocks_per_thread; + double avg_blocks_per_thread; + double load_imbalance_1, load_imbalance_2, load_imbalance; + + libxsmm_init(); /* !LIBXSMM_INIT */ + { unsigned int dummy = + LIBXSMM_ATOMIC_ADD_FETCH(&libxsmm_statistic_num_spmdm, 1, + LIBXSMM_ATOMIC_RELAXED); /* count number of invocations */ + LIBXSMM_UNUSED(dummy); + } + + handle->m = M; + handle->n = N; + handle->k = K; + handle->bm = (M >= 4096 || M <= 1024) ? 512 : 256; + +#if defined(LIBXSMM_SPMDM_AVX512_CORE) + if (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid || LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH) { + internal_spmdm_createSparseSlice_fp32_thread = internal_spmdm_createSparseSlice_fp32_thread_avx512_core; + internal_spmdm_createSparseSlice_bfloat16_thread = internal_spmdm_createSparseSlice_bfloat16_thread_avx512_core; + internal_spmdm_compute_fp32_thread = internal_spmdm_compute_fp32_thread_avx512_core; + internal_spmdm_compute_bfloat16_thread = internal_spmdm_compute_bfloat16_thread_avx512_core; + handle->bn = 96; + } + else +#endif +#if defined(LIBXSMM_SPMDM_AVX2) + if (LIBXSMM_X86_AVX2 <= libxsmm_target_archid || LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH) { + internal_spmdm_createSparseSlice_fp32_thread = internal_spmdm_createSparseSlice_fp32_thread_avx2; + internal_spmdm_createSparseSlice_bfloat16_thread = internal_spmdm_createSparseSlice_bfloat16_thread_avx2; + internal_spmdm_compute_fp32_thread = internal_spmdm_compute_fp32_thread_avx2; + internal_spmdm_compute_bfloat16_thread = internal_spmdm_compute_bfloat16_thread_avx2; + handle->bn = 48; + } + else +#endif + { + internal_spmdm_createSparseSlice_fp32_thread = internal_spmdm_createSparseSlice_fp32_thread_sw; + internal_spmdm_createSparseSlice_bfloat16_thread = internal_spmdm_createSparseSlice_bfloat16_thread_sw; + internal_spmdm_compute_fp32_thread = internal_spmdm_compute_fp32_thread_sw; + internal_spmdm_compute_bfloat16_thread = internal_spmdm_compute_bfloat16_thread_sw; + handle->bn = 6; + } + handle->bk = 128; + handle->mb = LIBXSMM_UPDIV(handle->m, handle->bm); + handle->nb = LIBXSMM_UPDIV(handle->n, handle->bn); + handle->kb = LIBXSMM_UPDIV(handle->k, handle->bk); + + max_work_per_block = handle->bm * handle->bn; + avg_work_per_block = (double)((size_t)handle->m * handle->n) / ((size_t)handle->mb * handle->nb); + load_imbalance_1 = max_work_per_block / avg_work_per_block; + max_blocks_per_thread = LIBXSMM_UPDIV(handle->mb * handle->nb, max_threads); + avg_blocks_per_thread = (double)handle->mb * handle->nb / max_threads; + load_imbalance_2 = max_blocks_per_thread / avg_blocks_per_thread; + load_imbalance = load_imbalance_1 * load_imbalance_2; + + while (32 < handle->bm && load_imbalance > load_imbalance_tolerate) { + handle->bm--; + handle->mb = LIBXSMM_UPDIV(handle->m, handle->bm); + + max_blocks_per_thread = LIBXSMM_UPDIV(handle->mb * handle->nb, max_threads); + avg_blocks_per_thread = (double)handle->mb * handle->nb / max_threads; + load_imbalance_2 = max_blocks_per_thread / avg_blocks_per_thread; + max_work_per_block = handle->bm * handle->bn; + avg_work_per_block = (double)((size_t)handle->m * handle->n) / ((size_t)handle->mb * handle->nb); + load_imbalance_1 = max_work_per_block / avg_work_per_block; + load_imbalance = load_imbalance_1 * load_imbalance_2; + } + + /* This is temporary space needed; allocate for each different size of a */ + internal_spmdm_allocate_csr_a(handle, libxsmm_output_csr); + internal_spmdm_allocate_scratch(handle, max_threads); + + /* Initialize shuffle masks for the computation */ +#if defined(LIBXSMM_SPMDM_AVX) + if (LIBXSMM_X86_AVX <= libxsmm_target_archid || LIBXSMM_X86_AVX <= LIBXSMM_STATIC_TARGET_ARCH) { + internal_spmdm_init_shufmask_avx(); + LIBXSMM_ASSERT(0 != internal_spmdm_shufmasks_32); + LIBXSMM_ASSERT(0 != internal_spmdm_shufmasks_16); + } +#endif + /* post-conditions */ + LIBXSMM_ASSERT(0 != internal_spmdm_createSparseSlice_fp32_thread); + LIBXSMM_ASSERT(0 != internal_spmdm_createSparseSlice_bfloat16_thread); + LIBXSMM_ASSERT(0 != internal_spmdm_compute_fp32_thread); + LIBXSMM_ASSERT(0 != internal_spmdm_compute_bfloat16_thread); +} + diff --git a/third_party/libxsmm/src/libxsmm_spmdm_begin.h b/third_party/libxsmm/src/libxsmm_spmdm_begin.h new file mode 100644 index 0000000000000000000000000000000000000000..af703326b0d93c90981c451cd0fc39ef69c77317 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_spmdm_begin.h @@ -0,0 +1,64 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish, Hans Pabst (Intel Corp.) +******************************************************************************/ + +#define SIMD_WIDTH_FP32 (1) +#define SIMDTYPE_FP32 float +#define SIMDTYPE_INT32 int +#define SIMDMASKTYPE_FP32 int +#define _MM_SETZERO_FP32() (0) +#define _MM_SETZERO_INT32() (0) +#define _MM_SET1_FP32(x) (x) +#define _MM_SET1_INT32(x) (x) +#define _MM_SET1_INT16 (x) +#define _MM_LOAD_FP32(x) (*(x)) +#define _MM_LOADU_FP32(x) (*(x)) +#define _MM_LOAD_INT32(x) (*(x)) +#define _MM_STORE_INT32(x,y) ((*(x)) = (y)) +#define _MM_LOADU_INT32(x) (*(x)) +#define _MM_GATHER_FP32(Addr, idx, scale) (*(Addr + (idx))) +#define _MM_CMPNEQ_FP32(v1,v2) (LIBXSMM_FEQ(v1, v2) ? 0 : 1) +#define _MM_STORE_FP32(x,y) ((*(x)) = (y)) +#define _MM_STOREU_FP32(x,y) ((*(x)) = (y)) +#define _MM_ADD_FP32(x,y) ((x) + (y)) +#define _MM_FMADD_FP32(x,y,z) (((x)*(y))+(z)) +#define _MM_MUL_FP32(x,y) ((x)*(y)) +#define _MM_PREFETCH(x, y) +#define TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_A, ldA, ptr_B, ldB) ((*(ptr_B)) = (*(ptr_A))) +#define TRANSPOSE_SIMD_WIDTH_KERNEL_BFLOAT16(ptr_A, ldA, ptr_B, ldB) { \ + uint16_t restmp = (*(ptr_A)); \ + union { int i; float f; } res; \ + res.i = restmp; \ + res.i <<= 16; \ + (*(ptr_B)) = res.f; \ +} + +#define COMPRESS_FP32(v, k, m, cnt) if (m) { \ + values_ptr[cnt] = v; \ + colidx_ptr[cnt] = (uint16_t)(k); \ + cnt++; \ +} + +#define EXPAND_BFLOAT16(v, vlo_final, vhi_final) { \ + union { int i; float f; } vlo_tmp, vhi_tmp; \ + vlo_tmp.i = (v) & 0xFFFF; vlo_tmp.i <<= 16; \ + vlo_final = vlo_tmp.f; \ + vhi_tmp.i = (v) & 0x0000FFFF; \ + vhi_final = vhi_tmp.f; \ +} + +#define COMPRESS_BFLOAT16(vlo, vhi, v) { \ + union { int i; float f; } vlo_tmp, vhi_tmp; \ + vlo_tmp.f = vlo; \ + v = (vlo_tmp.i >> 16); \ + vhi_tmp.f = vhi; \ + v = v | (vhi_tmp.i & 0xFFFF0000); \ +} + diff --git a/third_party/libxsmm/src/libxsmm_spmdm_begin_avx2.h b/third_party/libxsmm/src/libxsmm_spmdm_begin_avx2.h new file mode 100644 index 0000000000000000000000000000000000000000..0912a4895379c98508070c26acd0887696d8f3ae --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_spmdm_begin_avx2.h @@ -0,0 +1,166 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish, Hans Pabst (Intel Corp.) +******************************************************************************/ +#if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# error "libxsmm_intrinsics_x86.h not included!" +#endif + +#if (LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +#define SIMD_WIDTH_FP32 (8) +#define SIMDTYPE_FP32 __m256 +#define SIMDTYPE_INT32 __m256i +#define SIMDMASKTYPE_FP32 __m256 +#define _MM_SETZERO_FP32 _mm256_setzero_ps +#define _MM_SETZERO_INT32 _mm256_setzero_si256 +#define _MM_SET1_FP32 _mm256_set1_ps +#define _MM_SET1_INT32 _mm256_set1_epi32 +#define _MM_SET1_INT16 _mm256_set1_epi16 +#define _MM_SET_INT32 _mm256_set_epi32 +#define _MM_LOAD_FP32 _mm256_loadu_ps +#define _MM_LOADU_FP32 _mm256_loadu_ps +#define _MM_LOAD_INT32 _mm256_loadu_si256 +#define _MM_STORE_INT32 _mm256_storeu_si256 +#define _MM_LOADU_INT32(x) _mm256_loadu_si256( (__m256i const *)(x)) +#define _MM_GATHER_INT32(Addr, idx, scale) _mm256_i32gather_epi32((Addr), (idx), (scale)) +#define _MM_GATHER_FP32(Addr, idx, scale) _mm256_i32gather_ps(((float const *)(Addr)), (idx), (scale)) +#define _MM_CMPNEQ_FP32(v1,v2) _mm256_cmp_ps(v1,v2,12) +#define _MM_STORE_FP32 _mm256_storeu_ps +#define _MM_STOREU_FP32 _mm256_storeu_ps +#define _MM_ADD_FP32 _mm256_add_ps +#define _MM_FMADD_FP32 _mm256_fmadd_ps +#define _MM_MUL_FP32 _mm256_mul_ps +#define _MM_PREFETCH(x, y) _mm_prefetch(x, y) +#define TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_A, ldA, ptr_B, ldB) { \ + __m256 ymm9 = _mm256_loadu_ps(ptr_A); \ + __m256 ymm10 = _mm256_loadu_ps(ptr_A + (size_t)ldA); \ + __m256 ymm11 = _mm256_loadu_ps(ptr_A + (size_t)ldA*2); \ + __m256 ymm12 = _mm256_loadu_ps(ptr_A + (size_t)ldA*3); \ + __m256 ymm13 = _mm256_loadu_ps(ptr_A + (size_t)ldA*4); \ + __m256 ymm14 = _mm256_loadu_ps(ptr_A + (size_t)ldA*5); \ + __m256 ymm15 = _mm256_loadu_ps(ptr_A + (size_t)ldA*6); \ + __m256 ymm2 = _mm256_loadu_ps(ptr_A + (size_t)ldA*7); \ + __m256 ymm6 = _mm256_unpacklo_ps(ymm9, ymm10); \ + __m256 ymm1 = _mm256_unpacklo_ps(ymm11, ymm12); \ + __m256 ymm8 = _mm256_unpackhi_ps(ymm9, ymm10); \ + __m256 ymm0 = _mm256_unpacklo_ps(ymm13, ymm14); \ + ymm9 = _mm256_unpacklo_ps(ymm15, ymm2);{ \ + __m256 ymm3 = _mm256_shuffle_ps(ymm6, ymm1, 0x4E); \ + ymm10 = _mm256_blend_ps(ymm6, ymm3, 0xCC); \ + ymm6 = _mm256_shuffle_ps(ymm0, ymm9, 0x4E);{ \ + __m256 ymm7 = _mm256_unpackhi_ps(ymm11, ymm12); \ + ymm11 = _mm256_blend_ps(ymm0, ymm6, 0xCC); \ + ymm12 = _mm256_blend_ps(ymm3, ymm1, 0xCC); \ + ymm3 = _mm256_permute2f128_ps(ymm10, ymm11, 0x20); \ + _mm256_storeu_ps(ptr_B, ymm3);{ \ + __m256 ymm5 = _mm256_unpackhi_ps(ymm13, ymm14); \ + ymm13 = _mm256_blend_ps(ymm6, ymm9, 0xCC);{ \ + __m256 ymm4 = _mm256_unpackhi_ps(ymm15, ymm2); \ + ymm2 = _mm256_permute2f128_ps(ymm12, ymm13, 0x20); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB, ymm2); \ + ymm14 = _mm256_shuffle_ps(ymm8, ymm7, 0x4E); \ + ymm15 = _mm256_blend_ps(ymm14, ymm7, 0xCC); \ + ymm7 = _mm256_shuffle_ps(ymm5, ymm4, 0x4E); \ + ymm8 = _mm256_blend_ps(ymm8, ymm14, 0xCC); \ + ymm5 = _mm256_blend_ps(ymm5, ymm7, 0xCC); \ + ymm6 = _mm256_permute2f128_ps(ymm8, ymm5, 0x20); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*2, ymm6); \ + ymm4 = _mm256_blend_ps(ymm7, ymm4, 0xCC); \ + ymm7 = _mm256_permute2f128_ps(ymm15, ymm4, 0x20); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*3, ymm7); \ + ymm1 = _mm256_permute2f128_ps(ymm10, ymm11, 0x31); \ + ymm0 = _mm256_permute2f128_ps(ymm12, ymm13, 0x31); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*4, ymm1); \ + ymm5 = _mm256_permute2f128_ps(ymm8, ymm5, 0x31); \ + ymm4 = _mm256_permute2f128_ps(ymm15, ymm4, 0x31); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*5, ymm0); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*6, ymm5); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*7, ymm4);}}}} \ +} + +#define TRANSPOSE_SIMD_WIDTH_KERNEL_BFLOAT16(ptr_A, ldA, ptr_B, ldB) { \ + __m256 ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15, ymm2; \ + __m256i vload_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)(ptr_A))); \ + vload_1 = _mm256_inserti128_si256(vload_1, _mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA)), 1); \ + EXPAND_BFLOAT16(vload_1, ymm9, ymm10);{ \ + __m256i vload_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA*2))); \ + vload_2 = _mm256_inserti128_si256(vload_2, _mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA*3)), 1); \ + EXPAND_BFLOAT16(vload_2, ymm11, ymm12);{ \ + __m256i vload_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA*4))); \ + vload_3 = _mm256_inserti128_si256(vload_3, _mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA*5)), 1); \ + EXPAND_BFLOAT16(vload_3, ymm13, ymm14);{ \ + __m256i vload_4 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA*6))); \ + vload_4 = _mm256_inserti128_si256(vload_4, _mm_loadu_si128((const __m128i*)(ptr_A + (size_t)ldA*7)), 1); \ + EXPAND_BFLOAT16(vload_4, ymm15, ymm2);{ \ + __m256 ymm6 = _mm256_unpacklo_ps(ymm9, ymm10); \ + __m256 ymm1 = _mm256_unpacklo_ps(ymm11, ymm12); \ + __m256 ymm8 = _mm256_unpackhi_ps(ymm9, ymm10); \ + __m256 ymm0 = _mm256_unpacklo_ps(ymm13, ymm14); \ + ymm9 = _mm256_unpacklo_ps(ymm15, ymm2);{ \ + __m256 ymm3 = _mm256_shuffle_ps(ymm6, ymm1, 0x4E); \ + ymm10 = _mm256_blend_ps(ymm6, ymm3, 0xCC); \ + ymm6 = _mm256_shuffle_ps(ymm0, ymm9, 0x4E);{ \ + __m256 ymm7 = _mm256_unpackhi_ps(ymm11, ymm12); \ + ymm11 = _mm256_blend_ps(ymm0, ymm6, 0xCC); \ + ymm12 = _mm256_blend_ps(ymm3, ymm1, 0xCC); \ + ymm3 = _mm256_permute2f128_ps(ymm10, ymm11, 0x20); \ + _mm256_storeu_ps(ptr_B, ymm3);{ \ + __m256 ymm5 = _mm256_unpackhi_ps(ymm13, ymm14); \ + ymm13 = _mm256_blend_ps(ymm6, ymm9, 0xCC);{ \ + __m256 ymm4 = _mm256_unpackhi_ps(ymm15, ymm2); \ + ymm2 = _mm256_permute2f128_ps(ymm12, ymm13, 0x20); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB, ymm2); \ + ymm14 = _mm256_shuffle_ps(ymm8, ymm7, 0x4E); \ + ymm15 = _mm256_blend_ps(ymm14, ymm7, 0xCC); \ + ymm7 = _mm256_shuffle_ps(ymm5, ymm4, 0x4E); \ + ymm8 = _mm256_blend_ps(ymm8, ymm14, 0xCC); \ + ymm5 = _mm256_blend_ps(ymm5, ymm7, 0xCC); \ + ymm6 = _mm256_permute2f128_ps(ymm8, ymm5, 0x20); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*2, ymm6); \ + ymm4 = _mm256_blend_ps(ymm7, ymm4, 0xCC); \ + ymm7 = _mm256_permute2f128_ps(ymm15, ymm4, 0x20); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*3, ymm7); \ + ymm1 = _mm256_permute2f128_ps(ymm10, ymm11, 0x31); \ + ymm0 = _mm256_permute2f128_ps(ymm12, ymm13, 0x31); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*4, ymm1); \ + ymm5 = _mm256_permute2f128_ps(ymm8, ymm5, 0x31); \ + ymm4 = _mm256_permute2f128_ps(ymm15, ymm4, 0x31); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*5, ymm0); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*6, ymm5); \ + _mm256_storeu_ps(ptr_B + (size_t)ldB*7, ymm4);}}}}}}}} \ +} + +#define COMPRESS_FP32(v, k, m, cnt) { \ + const unsigned int mask = _mm256_movemask_ps(m); \ + const SIMDTYPE_INT32 vk = _MM_SET1_INT16((short)(k)); \ + const __m256i perm_ctrl = _mm256_loadu_si256(&shufmasks[mask]); \ + const __m256 v_packed = _mm256_permutevar8x32_ps(v, perm_ctrl); \ + const __m256i v_shuff = _mm256_loadu_si256(&shufmasks2[mask]); \ + const __m256i v_idx = _mm256_add_epi32(vk, v_shuff); \ + _mm256_storeu_ps(values_ptr + (cnt), v_packed); \ + _mm256_storeu_si256((__m256i *)(colidx_ptr + (cnt)), v_idx); \ + cnt = (unsigned short)((cnt) + _mm_popcnt_u32(mask)); \ +} + +#define EXPAND_BFLOAT16(v, vlo_final, vhi_final) { \ + const __m256i vlo = _mm256_unpacklo_epi16(vzero, v); \ + const __m256i vhi = _mm256_unpackhi_epi16(vzero, v); \ + vlo_final = _mm256_castsi256_ps(_mm256_permute2f128_si256(vlo, vhi, 0x20)); \ + vhi_final = _mm256_castsi256_ps(_mm256_permute2f128_si256(vlo, vhi, 0x31)); \ +} + +#define COMPRESS_BFLOAT16(vlo, vhi, v) { \ + const __m256i vtmp1 = _mm256_castps_si256(_mm256_permute2f128_ps(vlo, vhi, 0x20)); \ + const __m256i vtmp2 = _mm256_castps_si256(_mm256_permute2f128_ps(vlo, vhi, 0x31)); \ + const __m256i a = _mm256_srli_epi32(vtmp1, 16), b = _mm256_srli_epi32(vtmp2, 16); \ + v = _mm256_packus_epi32(a, b); \ +} + +#endif + diff --git a/third_party/libxsmm/src/libxsmm_spmdm_begin_avx512.h b/third_party/libxsmm/src/libxsmm_spmdm_begin_avx512.h new file mode 100644 index 0000000000000000000000000000000000000000..0174e28776bf1e5bc6b71f4b94d2f7ffefa82347 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_spmdm_begin_avx512.h @@ -0,0 +1,310 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish, Hans Pabst (Intel Corp.) +******************************************************************************/ +#if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH) +# error "libxsmm_intrinsics_x86.h not included!" +#endif + +#if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH) +#define SIMD_WIDTH_FP32 (16) +#define SIMDTYPE_FP32 __m512 +#define SIMDTYPE_INT32 __m512i +#define SIMDMASKTYPE_FP32 __mmask16 +#define _MM_SETZERO_FP32 _mm512_setzero_ps +#define _MM_SETZERO_INT32 _mm512_setzero_epi32 +#define _MM_SET1_FP32 _mm512_set1_ps +#define _MM_SET1_INT32 _mm512_set1_epi32 +#define _MM_SET1_INT16 _mm512_set1_epi16 +#define _MM_SET_INT32 _mm512_set_epi32 +#define _MM_LOAD_FP32 LIBXSMM_INTRINSICS_MM512_LOAD_PS +#define _MM_LOADU_FP32 _mm512_loadu_ps +#define _MM_LOAD_INT32 _mm512_loadu_si512 +#define _MM_STORE_INT32 _mm512_storeu_si512 +#define _MM_LOADU_INT32(x) _mm512_loadu_si512( (void const *)(x)) +#define _MM_GATHER_INT32(Addr, idx, scale) _mm512_i32gather_epi32((idx), (Addr), (scale)) +#define _MM_GATHER_FP32(Addr, idx, scale) _mm512_i32gather_ps((idx), (Addr), (scale)) +#define _MM_CMPNEQ_FP32(v1,v2) _mm512_cmp_ps_mask(v1,v2,12) +#define _MM_STORE_FP32 _mm512_storeu_ps +#define _MM_STOREU_FP32 _mm512_storeu_ps +#define _MM_ADD_FP32 _mm512_add_ps +#define _MM_FMADD_FP32 _mm512_fmadd_ps +#define _MM_MUL_FP32 _mm512_mul_ps +#define _MM_PREFETCH(x, y) _mm_prefetch(x, y) +#define TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_A, ldA, ptr_B, ldB) { \ + __m512 r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; \ + __m512 t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; \ + r0 = _mm512_loadu_ps(ptr_A); \ + r1 = _mm512_loadu_ps(ptr_A + ldA); \ + r2 = _mm512_loadu_ps(ptr_A + 2*ldA); \ + r3 = _mm512_loadu_ps(ptr_A + 3*ldA); \ + r4 = _mm512_loadu_ps(ptr_A + 4*ldA); \ + r5 = _mm512_loadu_ps(ptr_A + 5*ldA); \ + r6 = _mm512_loadu_ps(ptr_A + 6*ldA); \ + r7 = _mm512_loadu_ps(ptr_A + 7*ldA); \ + r8 = _mm512_loadu_ps(ptr_A + 8*ldA); \ + r9 = _mm512_loadu_ps(ptr_A + 9*ldA); \ + ra = _mm512_loadu_ps(ptr_A + 10*ldA); \ + rb = _mm512_loadu_ps(ptr_A + 11*ldA); \ + rc = _mm512_loadu_ps(ptr_A + 12*ldA); \ + rd = _mm512_loadu_ps(ptr_A + 13*ldA); \ + re = _mm512_loadu_ps(ptr_A + 14*ldA); \ + rf = _mm512_loadu_ps(ptr_A + 15*ldA); \ + \ + t0 = _mm512_unpacklo_ps(r0,r1); \ + t1 = _mm512_unpackhi_ps(r0,r1); \ + t2 = _mm512_unpacklo_ps(r2,r3); \ + t3 = _mm512_unpackhi_ps(r2,r3); \ + t4 = _mm512_unpacklo_ps(r4,r5); \ + t5 = _mm512_unpackhi_ps(r4,r5); \ + t6 = _mm512_unpacklo_ps(r6,r7); \ + t7 = _mm512_unpackhi_ps(r6,r7); \ + t8 = _mm512_unpacklo_ps(r8,r9); \ + t9 = _mm512_unpackhi_ps(r8,r9); \ + ta = _mm512_unpacklo_ps(ra,rb); \ + tb = _mm512_unpackhi_ps(ra,rb); \ + tc = _mm512_unpacklo_ps(rc,rd); \ + td = _mm512_unpackhi_ps(rc,rd); \ + te = _mm512_unpacklo_ps(re,rf); \ + tf = _mm512_unpackhi_ps(re,rf); \ + \ + { const __m512d td1 = _mm512_castps_pd(t0), td2 = _mm512_castps_pd(t2); \ + r0 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r1 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(t1), td2 = _mm512_castps_pd(t3); \ + r2 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r3 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(t4), td2 = _mm512_castps_pd(t6); \ + r4 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r5 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(t5), td2 = _mm512_castps_pd(t7); \ + r6 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r7 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(t8), td2 = _mm512_castps_pd(ta); \ + r8 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r9 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(t9), td2 = _mm512_castps_pd(tb); \ + ra = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + rb = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(tc), td2 = _mm512_castps_pd(te); \ + rc = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + rd = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + { const __m512d td1 = _mm512_castps_pd(td), td2 = _mm512_castps_pd(tf); \ + re = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + rf = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2));} \ + \ + t0 = _mm512_shuffle_f32x4(r0, r4, 0x88); \ + t1 = _mm512_shuffle_f32x4(r1, r5, 0x88); \ + t2 = _mm512_shuffle_f32x4(r2, r6, 0x88); \ + t3 = _mm512_shuffle_f32x4(r3, r7, 0x88); \ + t4 = _mm512_shuffle_f32x4(r0, r4, 0xdd); \ + t5 = _mm512_shuffle_f32x4(r1, r5, 0xdd); \ + t6 = _mm512_shuffle_f32x4(r2, r6, 0xdd); \ + t7 = _mm512_shuffle_f32x4(r3, r7, 0xdd); \ + t8 = _mm512_shuffle_f32x4(r8, rc, 0x88); \ + t9 = _mm512_shuffle_f32x4(r9, rd, 0x88); \ + ta = _mm512_shuffle_f32x4(ra, re, 0x88); \ + tb = _mm512_shuffle_f32x4(rb, rf, 0x88); \ + tc = _mm512_shuffle_f32x4(r8, rc, 0xdd); \ + td = _mm512_shuffle_f32x4(r9, rd, 0xdd); \ + te = _mm512_shuffle_f32x4(ra, re, 0xdd); \ + tf = _mm512_shuffle_f32x4(rb, rf, 0xdd); \ + \ + r0 = _mm512_shuffle_f32x4(t0, t8, 0x88); \ + r1 = _mm512_shuffle_f32x4(t1, t9, 0x88); \ + r2 = _mm512_shuffle_f32x4(t2, ta, 0x88); \ + r3 = _mm512_shuffle_f32x4(t3, tb, 0x88); \ + r4 = _mm512_shuffle_f32x4(t4, tc, 0x88); \ + r5 = _mm512_shuffle_f32x4(t5, td, 0x88); \ + r6 = _mm512_shuffle_f32x4(t6, te, 0x88); \ + r7 = _mm512_shuffle_f32x4(t7, tf, 0x88); \ + r8 = _mm512_shuffle_f32x4(t0, t8, 0xdd); \ + r9 = _mm512_shuffle_f32x4(t1, t9, 0xdd); \ + ra = _mm512_shuffle_f32x4(t2, ta, 0xdd); \ + rb = _mm512_shuffle_f32x4(t3, tb, 0xdd); \ + rc = _mm512_shuffle_f32x4(t4, tc, 0xdd); \ + rd = _mm512_shuffle_f32x4(t5, td, 0xdd); \ + re = _mm512_shuffle_f32x4(t6, te, 0xdd); \ + rf = _mm512_shuffle_f32x4(t7, tf, 0xdd); \ + \ + _mm512_storeu_ps(ptr_B + 0*ldB, r0); \ + _mm512_storeu_ps(ptr_B + 1*ldB, r1); \ + _mm512_storeu_ps(ptr_B + 2*ldB, r2); \ + _mm512_storeu_ps(ptr_B + 3*ldB, r3); \ + _mm512_storeu_ps(ptr_B + 4*ldB, r4); \ + _mm512_storeu_ps(ptr_B + 5*ldB, r5); \ + _mm512_storeu_ps(ptr_B + 6*ldB, r6); \ + _mm512_storeu_ps(ptr_B + 7*ldB, r7); \ + _mm512_storeu_ps(ptr_B + 8*ldB, r8); \ + _mm512_storeu_ps(ptr_B + 9*ldB, r9); \ + _mm512_storeu_ps(ptr_B + 10*ldB, ra); \ + _mm512_storeu_ps(ptr_B + 11*ldB, rb); \ + _mm512_storeu_ps(ptr_B + 12*ldB, rc); \ + _mm512_storeu_ps(ptr_B + 13*ldB, rd); \ + _mm512_storeu_ps(ptr_B + 14*ldB, re); \ + _mm512_storeu_ps(ptr_B + 15*ldB, rf); \ +} + +#define TRANSPOSE_SIMD_WIDTH_KERNEL_BFLOAT16(ptr_A, ldA, ptr_B, ldB) { \ + __m512 r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; \ + __m512 t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; \ + __m512i vload_1 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A))); \ + vload_1 = _mm512_inserti32x8(vload_1, _mm256_loadu_si256((const __m256i*)(ptr_A + ldA)), 1); \ + EXPAND_BFLOAT16(vload_1, r0, r1);{ \ + __m512i vload_2 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 2*ldA))); \ + vload_2 = _mm512_inserti32x8(vload_2, _mm256_loadu_si256((const __m256i*)(ptr_A + 3*ldA)), 1); \ + EXPAND_BFLOAT16(vload_2, r2, r3);{ \ + __m512i vload_3 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 4*ldA))); \ + vload_3 = _mm512_inserti32x8(vload_3, _mm256_loadu_si256((const __m256i*)(ptr_A + 5*ldA)), 1); \ + EXPAND_BFLOAT16(vload_3, r4, r5);{ \ + __m512i vload_4 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 6*ldA))); \ + vload_4 = _mm512_inserti32x8(vload_4, _mm256_loadu_si256((const __m256i*)(ptr_A + 7*ldA)), 1); \ + EXPAND_BFLOAT16(vload_4, r6, r7);{ \ + __m512i vload_5 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 8*ldA))); \ + vload_5 = _mm512_inserti32x8(vload_5, _mm256_loadu_si256((const __m256i*)(ptr_A + 9*ldA)), 1); \ + EXPAND_BFLOAT16(vload_5, r8, r9);{ \ + __m512i vload_6 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 10*ldA))); \ + vload_6 = _mm512_inserti32x8(vload_6, _mm256_loadu_si256((const __m256i*)(ptr_A + 11*ldA)), 1); \ + EXPAND_BFLOAT16(vload_6, ra, rb);{ \ + __m512i vload_7 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 12*ldA))); \ + vload_7 = _mm512_inserti32x8(vload_7, _mm256_loadu_si256((const __m256i*)(ptr_A + 13*ldA)), 1); \ + EXPAND_BFLOAT16(vload_7, rc, rd);{ \ + __m512i vload_8 = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(ptr_A + 14*ldA))); \ + vload_8 = _mm512_inserti32x8(vload_8, _mm256_loadu_si256((const __m256i*)(ptr_A + 15*ldA)), 1); \ + EXPAND_BFLOAT16(vload_8, re, rf); \ + \ + t0 = _mm512_unpacklo_ps(r0,r1); \ + t1 = _mm512_unpackhi_ps(r0,r1); \ + t2 = _mm512_unpacklo_ps(r2,r3); \ + t3 = _mm512_unpackhi_ps(r2,r3); \ + t4 = _mm512_unpacklo_ps(r4,r5); \ + t5 = _mm512_unpackhi_ps(r4,r5); \ + t6 = _mm512_unpacklo_ps(r6,r7); \ + t7 = _mm512_unpackhi_ps(r6,r7); \ + t8 = _mm512_unpacklo_ps(r8,r9); \ + t9 = _mm512_unpackhi_ps(r8,r9); \ + ta = _mm512_unpacklo_ps(ra,rb); \ + tb = _mm512_unpackhi_ps(ra,rb); \ + tc = _mm512_unpacklo_ps(rc,rd); \ + td = _mm512_unpackhi_ps(rc,rd); \ + te = _mm512_unpacklo_ps(re,rf); \ + tf = _mm512_unpackhi_ps(re,rf); \ + \ + { const __m512d td1 = _mm512_castps_pd(t0), td2 = _mm512_castps_pd(t2); \ + r0 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r1 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(t1), td2 = _mm512_castps_pd(t3); \ + r2 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r3 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(t4), td2 = _mm512_castps_pd(t6); \ + r4 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r5 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(t5), td2 = _mm512_castps_pd(t7); \ + r6 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r7 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(t8), td2 = _mm512_castps_pd(ta); \ + r8 = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + r9 = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(t9), td2 = _mm512_castps_pd(tb); \ + ra = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + rb = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(tc), td2 = _mm512_castps_pd(te); \ + rc = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + rd = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + { const __m512d td1 = _mm512_castps_pd(td), td2 = _mm512_castps_pd(tf); \ + re = _mm512_castpd_ps(_mm512_unpacklo_pd(td1, td2)); \ + rf = _mm512_castpd_ps(_mm512_unpackhi_pd(td1, td2)); } \ + \ + t0 = _mm512_shuffle_f32x4(r0, r4, 0x88); \ + t1 = _mm512_shuffle_f32x4(r1, r5, 0x88); \ + t2 = _mm512_shuffle_f32x4(r2, r6, 0x88); \ + t3 = _mm512_shuffle_f32x4(r3, r7, 0x88); \ + t4 = _mm512_shuffle_f32x4(r0, r4, 0xdd); \ + t5 = _mm512_shuffle_f32x4(r1, r5, 0xdd); \ + t6 = _mm512_shuffle_f32x4(r2, r6, 0xdd); \ + t7 = _mm512_shuffle_f32x4(r3, r7, 0xdd); \ + t8 = _mm512_shuffle_f32x4(r8, rc, 0x88); \ + t9 = _mm512_shuffle_f32x4(r9, rd, 0x88); \ + ta = _mm512_shuffle_f32x4(ra, re, 0x88); \ + tb = _mm512_shuffle_f32x4(rb, rf, 0x88); \ + tc = _mm512_shuffle_f32x4(r8, rc, 0xdd); \ + td = _mm512_shuffle_f32x4(r9, rd, 0xdd); \ + te = _mm512_shuffle_f32x4(ra, re, 0xdd); \ + tf = _mm512_shuffle_f32x4(rb, rf, 0xdd); \ + \ + r0 = _mm512_shuffle_f32x4(t0, t8, 0x88); \ + r1 = _mm512_shuffle_f32x4(t1, t9, 0x88); \ + r2 = _mm512_shuffle_f32x4(t2, ta, 0x88); \ + r3 = _mm512_shuffle_f32x4(t3, tb, 0x88); \ + r4 = _mm512_shuffle_f32x4(t4, tc, 0x88); \ + r5 = _mm512_shuffle_f32x4(t5, td, 0x88); \ + r6 = _mm512_shuffle_f32x4(t6, te, 0x88); \ + r7 = _mm512_shuffle_f32x4(t7, tf, 0x88); \ + r8 = _mm512_shuffle_f32x4(t0, t8, 0xdd); \ + r9 = _mm512_shuffle_f32x4(t1, t9, 0xdd); \ + ra = _mm512_shuffle_f32x4(t2, ta, 0xdd); \ + rb = _mm512_shuffle_f32x4(t3, tb, 0xdd); \ + rc = _mm512_shuffle_f32x4(t4, tc, 0xdd); \ + rd = _mm512_shuffle_f32x4(t5, td, 0xdd); \ + re = _mm512_shuffle_f32x4(t6, te, 0xdd); \ + rf = _mm512_shuffle_f32x4(t7, tf, 0xdd); \ + \ + _mm512_storeu_ps(ptr_B + 0*ldB, r0); \ + _mm512_storeu_ps(ptr_B + 1*ldB, r1); \ + _mm512_storeu_ps(ptr_B + 2*ldB, r2); \ + _mm512_storeu_ps(ptr_B + 3*ldB, r3); \ + _mm512_storeu_ps(ptr_B + 4*ldB, r4); \ + _mm512_storeu_ps(ptr_B + 5*ldB, r5); \ + _mm512_storeu_ps(ptr_B + 6*ldB, r6); \ + _mm512_storeu_ps(ptr_B + 7*ldB, r7); \ + _mm512_storeu_ps(ptr_B + 8*ldB, r8); \ + _mm512_storeu_ps(ptr_B + 9*ldB, r9); \ + _mm512_storeu_ps(ptr_B + 10*ldB, ra); \ + _mm512_storeu_ps(ptr_B + 11*ldB, rb); \ + _mm512_storeu_ps(ptr_B + 12*ldB, rc); \ + _mm512_storeu_ps(ptr_B + 13*ldB, rd); \ + _mm512_storeu_ps(ptr_B + 14*ldB, re); \ + _mm512_storeu_ps(ptr_B + 15*ldB, rf);}}}}}}} \ +} + +#define COMPRESS_FP32(v, k, m, cnt) { \ + _mm512_mask_compressstoreu_ps(values_ptr + (cnt), m, v); \ + { \ + __m256i vk1 = _mm256_set1_epi16((short)(k)); \ + __m256i vk2 = _mm256_set1_epi16((short)((k) + 8)); \ + __m256i v_idx = _mm256_add_epi32(vk1, _mm256_loadu_si256(&shufmasks2[(m)&0xFF])); \ + __m256i v_idx_2 = _mm256_add_epi32(vk2, _mm256_loadu_si256(&shufmasks2[((m)>>8)&0xFF])); \ + _mm256_storeu_si256((__m256i *)(colidx_ptr + (cnt)), v_idx); \ + cnt = (unsigned short)((cnt) + _mm_popcnt_u32((m)&0xFF)); \ + _mm256_storeu_si256((__m256i *)(colidx_ptr + (cnt)), v_idx_2); \ + cnt = (unsigned short)((cnt) + _mm_popcnt_u32(((m)>>8)&0xFF)); \ + } \ +} + +#define EXPAND_BFLOAT16(v, vlo_final, vhi_final) { \ + const __m512i vlo = _mm512_unpacklo_epi16(vzero, v); \ + const __m512i vhi = _mm512_unpackhi_epi16(vzero, v); \ + const __m512i permmask1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); \ + const __m512i permmask2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); \ + vlo_final = _mm512_castsi512_ps(_mm512_permutex2var_epi64(vlo, permmask1, vhi)); \ + vhi_final = _mm512_castsi512_ps(_mm512_permutex2var_epi64(vlo, permmask2, vhi)); \ +} + +#define COMPRESS_BFLOAT16(vlo, vhi, v) { \ + const __m512i permmask1 = _mm512_set_epi64(13, 12, 9, 8, 5, 4, 1, 0); \ + const __m512i permmask2 = _mm512_set_epi64(15, 14, 11, 10, 7, 6, 3, 2); \ + const __m512i va = _mm512_castps_si512(vlo), vb = _mm512_castps_si512(vhi); \ + const __m512i vtmp1 = _mm512_permutex2var_epi64(va, permmask1, vb); \ + const __m512i vtmp2 = _mm512_permutex2var_epi64(va, permmask2, vb); \ + const __m512i a = _mm512_srli_epi32(vtmp1, 16), b = _mm512_srli_epi32(vtmp2, 16); \ + v = _mm512_packus_epi32(a, b); \ +} + +#endif + diff --git a/third_party/libxsmm/src/libxsmm_spmdm_end.h b/third_party/libxsmm/src/libxsmm_spmdm_end.h new file mode 100644 index 0000000000000000000000000000000000000000..12bd27f735f4997953dcbd7126b3e9777a84eafe --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_spmdm_end.h @@ -0,0 +1,42 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ + +#undef SIMD_WIDTH_FP32 +#undef SIMDTYPE_FP32 +#undef SIMDTYPE_INT32 +#undef SIMDMASKTYPE_FP32 +#undef _MM_SETZERO_FP32 +#undef _MM_SETZERO_INT32 +#undef _MM_SET1_FP32 +#undef _MM_SET1_INT32 +#undef _MM_SET1_INT16 +#undef _MM_SET_INT32 +#undef _MM_LOAD_FP32 +#undef _MM_LOADU_FP32 +#undef _MM_LOAD_INT32 +#undef _MM_STORE_INT32 +#undef _MM_LOADU_INT32 +#undef _MM_GATHER_INT32 +#undef _MM_GATHER_FP32 +#undef _MM_CMPNEQ_FP32 +#undef _MM_STORE_FP32 +#undef _MM_STOREU_FP32 +#undef _MM_ADD_FP32 +#undef _MM_FMADD_FP32 +#undef _MM_MUL_FP32 +#undef _MM_PREFETCH +#undef TRANSPOSE_SIMD_WIDTH_KERNEL +#undef TRANSPOSE_SIMD_WIDTH_KERNEL_BFLOAT16 +#undef COMPRESS_FP32 +#undef EXPAND_BFLOAT16 +#undef COMPRESS_BFLOAT16 +#undef num_regs + diff --git a/third_party/libxsmm/src/libxsmm_sync.c b/third_party/libxsmm/src/libxsmm_sync.c new file mode 100644 index 0000000000000000000000000000000000000000..40dace51ede06f36ca38a004c34ddf6e862af205 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_sync.c @@ -0,0 +1,673 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +/* Lock primitives inspired by Karl Malbrain, Concurrency Kit, and TF/sync. +******************************************************************************/ +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_SYNC_FUTEX) && defined(__linux__) && defined(__USE_GNU) +# define LIBXSMM_SYNC_FUTEX +#endif + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#include +#if defined(_WIN32) +# include +#else +# if defined(LIBXSMM_SYNC_FUTEX) && defined(__linux__) && defined(__USE_GNU) +# include +# endif +# include +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if !defined(LIBXSMM_SYNC_RWLOCK_BITS) +# if defined(__MINGW32__) +# define LIBXSMM_SYNC_RWLOCK_BITS 32 +# else +# define LIBXSMM_SYNC_RWLOCK_BITS 16 +# endif +#endif + +#if !defined(LIBXSMM_SYNC_GENERIC_PID) && 1 +# define LIBXSMM_SYNC_GENERIC_PID +#endif + + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE internal_sync_core_tag { /* per-core */ + uint8_t id; + volatile uint8_t core_sense; + volatile uint8_t* thread_senses; + volatile uint8_t* my_flags[2]; + uint8_t** partner_flags[2]; + uint8_t parity; + uint8_t sense; +} internal_sync_core_tag; + +LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE internal_sync_thread_tag { /* per-thread */ + int core_tid; + internal_sync_core_tag *core; +} internal_sync_thread_tag; + +struct LIBXSMM_RETARGETABLE libxsmm_barrier { + internal_sync_core_tag** cores; + internal_sync_thread_tag** threads; + int ncores, nthreads_per_core; + int nthreads, ncores_nbits; /* nbits(ncores) != log2(ncores) */ + /* internal counter type which is guaranteed to be atomic when using certain methods */ + volatile int threads_waiting; + /* thread-safety during initialization */ + volatile uint8_t init_done; +}; + + +LIBXSMM_API libxsmm_barrier* libxsmm_barrier_create(int ncores, int nthreads_per_core) +{ + libxsmm_barrier *const barrier = (libxsmm_barrier*)malloc(sizeof(libxsmm_barrier)); +#if (0 == LIBXSMM_SYNC) + LIBXSMM_UNUSED(ncores); LIBXSMM_UNUSED(nthreads_per_core); +#else + if (NULL != barrier && 1 < ncores && 1 <= nthreads_per_core) { + barrier->ncores = ncores; + barrier->ncores_nbits = (int)LIBXSMM_NBITS(ncores); + barrier->nthreads_per_core = nthreads_per_core; + barrier->nthreads = ncores * nthreads_per_core; + barrier->threads = (internal_sync_thread_tag**)libxsmm_aligned_malloc( + barrier->nthreads * sizeof(internal_sync_thread_tag*), LIBXSMM_CACHELINE); + barrier->cores = (internal_sync_core_tag**)libxsmm_aligned_malloc( + barrier->ncores * sizeof(internal_sync_core_tag*), LIBXSMM_CACHELINE); + barrier->threads_waiting = barrier->nthreads; /* atomic */ + barrier->init_done = 0; /* false */ + } + else +#endif + if (NULL != barrier) { + barrier->nthreads = 1; + } + return barrier; +} + + +LIBXSMM_API void libxsmm_barrier_init(libxsmm_barrier* barrier, int tid) +{ +#if (0 == LIBXSMM_SYNC) + LIBXSMM_UNUSED(barrier); LIBXSMM_UNUSED(tid); +#else + if (NULL != barrier && 1 < barrier->nthreads) { + const int cid = tid / barrier->nthreads_per_core; /* this thread's core ID */ + internal_sync_core_tag* core = 0; + int i; + internal_sync_thread_tag* thread; + + /* we only initialize the barrier once */ + if (barrier->init_done == 2) { + return; + } + + /* allocate per-thread structure */ + thread = (internal_sync_thread_tag*)libxsmm_aligned_malloc( + sizeof(internal_sync_thread_tag), LIBXSMM_CACHELINE); + barrier->threads[tid] = thread; + thread->core_tid = tid - (barrier->nthreads_per_core * cid); /* mod */ + + /* each core's thread 0 does all the allocations */ + if (0 == thread->core_tid) { + core = (internal_sync_core_tag*)libxsmm_aligned_malloc( + sizeof(internal_sync_core_tag), LIBXSMM_CACHELINE); + core->id = (uint8_t)cid; + core->core_sense = 1; + + core->thread_senses = (uint8_t*)libxsmm_aligned_malloc( + barrier->nthreads_per_core * sizeof(uint8_t), LIBXSMM_CACHELINE); + for (i = 0; i < barrier->nthreads_per_core; ++i) core->thread_senses[i] = 1; + + for (i = 0; i < 2; ++i) { + core->my_flags[i] = (uint8_t*)libxsmm_aligned_malloc( + barrier->ncores_nbits * sizeof(uint8_t) * LIBXSMM_CACHELINE, + LIBXSMM_CACHELINE); + core->partner_flags[i] = (uint8_t**)libxsmm_aligned_malloc( + barrier->ncores_nbits * sizeof(uint8_t*), + LIBXSMM_CACHELINE); + } + + core->parity = 0; + core->sense = 1; + barrier->cores[cid] = core; + } + + /* barrier to let all the allocations complete */ + if (0 == LIBXSMM_ATOMIC_SUB_FETCH(&barrier->threads_waiting, 1, LIBXSMM_ATOMIC_RELAXED)) { + barrier->threads_waiting = barrier->nthreads; /* atomic */ + barrier->init_done = 1; /* true */ + } + else { + while (0/*false*/ == barrier->init_done); + } + + /* set required per-thread information */ + thread->core = barrier->cores[cid]; + + /* each core's thread 0 completes setup */ + if (0 == thread->core_tid) { + int di; + for (i = di = 0; i < barrier->ncores_nbits; ++i, di += LIBXSMM_CACHELINE) { + /* find dissemination partner and link to it */ + const int dissem_cid = (cid + (1 << i)) % barrier->ncores; + assert(0 != core); /* initialized under the same condition; see above */ + core->my_flags[0][di] = core->my_flags[1][di] = 0; + core->partner_flags[0][i] = (uint8_t*)&barrier->cores[dissem_cid]->my_flags[0][di]; + core->partner_flags[1][i] = (uint8_t*)&barrier->cores[dissem_cid]->my_flags[1][di]; + } + } + + /* barrier to let initialization complete */ + if (0 == LIBXSMM_ATOMIC_SUB_FETCH(&barrier->threads_waiting, 1, LIBXSMM_ATOMIC_RELAXED)) { + barrier->threads_waiting = barrier->nthreads; /* atomic */ + barrier->init_done = 2; + } + else { + while (2 != barrier->init_done); + } + } +#endif +} + + +LIBXSMM_API LIBXSMM_INTRINSICS(LIBXSMM_X86_GENERIC) +void libxsmm_barrier_wait(libxsmm_barrier* barrier, int tid) +{ +#if (0 == LIBXSMM_SYNC) + LIBXSMM_UNUSED(barrier); LIBXSMM_UNUSED(tid); +#else + if (NULL != barrier && 1 < barrier->nthreads) { + internal_sync_thread_tag *const thread = barrier->threads[tid]; + internal_sync_core_tag *const core = thread->core; + + /* first let's execute a memory fence */ + LIBXSMM_ATOMIC_SYNC(LIBXSMM_ATOMIC_SEQ_CST); + + /* first signal this thread's arrival */ + core->thread_senses[thread->core_tid] = (uint8_t)(0 == core->thread_senses[thread->core_tid] ? 1 : 0); + + /* each core's thread 0 syncs across cores */ + if (0 == thread->core_tid) { + int i; + /* wait for the core's remaining threads */ + for (i = 1; i < barrier->nthreads_per_core; ++i) { + uint8_t core_sense = core->core_sense, thread_sense = core->thread_senses[i]; + while (core_sense == thread_sense) { /* avoid evaluation in unspecified order */ + LIBXSMM_SYNC_PAUSE; + core_sense = core->core_sense; + thread_sense = core->thread_senses[i]; + } + } + + if (1 < barrier->ncores) { + int di; +# if defined(__MIC__) + /* cannot use LIBXSMM_ALIGNED since attribute may not apply to local non-static arrays */ + uint8_t sendbuffer[LIBXSMM_CACHELINE+LIBXSMM_CACHELINE-1]; + uint8_t *const sendbuf = LIBXSMM_ALIGN(sendbuffer, LIBXSMM_CACHELINE); + __m512d m512d; + _mm_prefetch((const char*)core->partner_flags[core->parity][0], _MM_HINT_ET1); + sendbuf[0] = core->sense; + m512d = LIBXSMM_INTRINSICS_MM512_LOAD_PD(sendbuf); +# endif + + for (i = di = 0; i < barrier->ncores_nbits - 1; ++i, di += LIBXSMM_CACHELINE) { +# if defined(__MIC__) + _mm_prefetch((const char*)core->partner_flags[core->parity][i+1], _MM_HINT_ET1); + _mm512_storenrngo_pd(core->partner_flags[core->parity][i], m512d); +# else + *core->partner_flags[core->parity][i] = core->sense; +# endif + while (core->my_flags[core->parity][di] != core->sense) LIBXSMM_SYNC_PAUSE; + } + +# if defined(__MIC__) + _mm512_storenrngo_pd(core->partner_flags[core->parity][i], m512d); +# else + *core->partner_flags[core->parity][i] = core->sense; +# endif + while (core->my_flags[core->parity][di] != core->sense) LIBXSMM_SYNC_PAUSE; + if (1 == core->parity) { + core->sense = (uint8_t)(0 == core->sense ? 1 : 0); + } + core->parity = (uint8_t)(1 - core->parity); + } + + /* wake up the core's remaining threads */ + core->core_sense = core->thread_senses[0]; + } + else { /* other threads wait for cross-core sync to complete */ + uint8_t core_sense = core->core_sense, thread_sense = core->thread_senses[thread->core_tid]; + while (core_sense != thread_sense) { /* avoid evaluation in unspecified order */ + LIBXSMM_SYNC_PAUSE; + core_sense = core->core_sense; + thread_sense = core->thread_senses[thread->core_tid]; + } + } + } +#endif +} + + +LIBXSMM_API void libxsmm_barrier_destroy(const libxsmm_barrier* barrier) +{ +#if (0 != LIBXSMM_SYNC) + if (NULL != barrier && 1 < barrier->nthreads) { + if (2 == barrier->init_done) { + int i; + for (i = 0; i < barrier->ncores; ++i) { + int j; + libxsmm_free((const void*)barrier->cores[i]->thread_senses); + for (j = 0; j < 2; ++j) { + libxsmm_free((const void*)barrier->cores[i]->my_flags[j]); + libxsmm_free(barrier->cores[i]->partner_flags[j]); + } + libxsmm_free(barrier->cores[i]); + } + for (i = 0; i < barrier->nthreads; ++i) { + libxsmm_free(barrier->threads[i]); + } + } + libxsmm_free(barrier->threads); + libxsmm_free(barrier->cores); + } +#endif + free((libxsmm_barrier*)barrier); +} + + +#if (0 != LIBXSMM_SYNC) +enum { + INTERNAL_SYNC_LOCK_FREE = 0, + INTERNAL_SYNC_LOCK_LOCKED = 1, + INTERNAL_SYNC_LOCK_CONTESTED = 2, + INTERNAL_SYNC_RWLOCK_READINC = 0x10000/*(USHRT_MAX+1)*/, + INTERNAL_SYNC_FUTEX = 202 +}; +#endif + + +typedef unsigned int libxsmm_spinlock_state; +struct LIBXSMM_RETARGETABLE libxsmm_spinlock { + volatile libxsmm_spinlock_state state; +}; + + +LIBXSMM_API libxsmm_spinlock* libxsmm_spinlock_create(void) +{ + libxsmm_spinlock *const result = (libxsmm_spinlock*)malloc(sizeof(libxsmm_spinlock)); +#if (0 != LIBXSMM_SYNC) + if (0 != result) { + result->state = INTERNAL_SYNC_LOCK_FREE; + } +#endif + return result; +} + + +LIBXSMM_API void libxsmm_spinlock_destroy(const libxsmm_spinlock* spinlock) +{ + free((libxsmm_spinlock*)spinlock); +} + + +LIBXSMM_API int libxsmm_spinlock_trylock(libxsmm_spinlock* spinlock) +{ +#if (0 != LIBXSMM_SYNC) +# if 0 + /*const*/ libxsmm_spinlock_state lock_free = INTERNAL_SYNC_LOCK_FREE; + assert(0 != spinlock); + return 0/*false*/ == LIBXSMM_ATOMIC_CMPSWP(&spinlock->state, lock_free, INTERNAL_SYNC_LOCK_LOCKED, LIBXSMM_ATOMIC_RELAXED) + ? (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_SPINLOCK) + 1) /* not acquired */ + : (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_SPINLOCK)); +# else + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_SPINLOCK) + !LIBXSMM_ATOMIC_TRYLOCK(&spinlock->state, LIBXSMM_ATOMIC_RELAXED); +# endif +#else + LIBXSMM_UNUSED(spinlock); + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_SPINLOCK); +#endif +} + + +LIBXSMM_API void libxsmm_spinlock_acquire(libxsmm_spinlock* spinlock) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != spinlock); + for (;;) { + if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&spinlock->state, 1, LIBXSMM_ATOMIC_RELAXED)) break; + LIBXSMM_SYNC_CYCLE(&spinlock->state, INTERNAL_SYNC_LOCK_FREE, LIBXSMM_SYNC_NPAUSE); + } + LIBXSMM_ATOMIC_SYNC(LIBXSMM_ATOMIC_SEQ_CST); +#else + LIBXSMM_UNUSED(spinlock); +#endif +} + + +LIBXSMM_API void libxsmm_spinlock_release(libxsmm_spinlock* spinlock) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != spinlock); + LIBXSMM_ATOMIC_SYNC(LIBXSMM_ATOMIC_SEQ_CST); + spinlock->state = INTERNAL_SYNC_LOCK_FREE; +#else + LIBXSMM_UNUSED(spinlock); +#endif +} + + +#if defined(LIBXSMM_SYNC_FUTEX) && defined(__linux__) && defined(__USE_GNU) +typedef int libxsmm_mutex_state; +#else +typedef char libxsmm_mutex_state; +#endif +struct LIBXSMM_RETARGETABLE libxsmm_mutex { + volatile libxsmm_mutex_state state; +}; + + +LIBXSMM_API libxsmm_mutex* libxsmm_mutex_create(void) +{ + libxsmm_mutex *const result = (libxsmm_mutex*)malloc(sizeof(libxsmm_mutex)); +#if (0 != LIBXSMM_SYNC) + if (0 != result) { + result->state = INTERNAL_SYNC_LOCK_FREE; + } +#endif + return result; +} + + +LIBXSMM_API void libxsmm_mutex_destroy(const libxsmm_mutex* mutex) +{ + free((libxsmm_mutex*)mutex); +} + + +LIBXSMM_API int libxsmm_mutex_trylock(libxsmm_mutex* mutex) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != mutex); + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_MUTEX) + !LIBXSMM_ATOMIC_TRYLOCK(&mutex->state, LIBXSMM_ATOMIC_RELAXED); +#else + LIBXSMM_UNUSED(mutex); + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_MUTEX); +#endif +} + + +LIBXSMM_API void libxsmm_mutex_acquire(libxsmm_mutex* mutex) +{ +#if (0 != LIBXSMM_SYNC) +# if defined(_WIN32) + assert(0 != mutex); + while (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_MUTEX) != libxsmm_mutex_trylock(mutex)) { + LIBXSMM_SYNC_CYCLE(&mutex->state, 0/*free*/, LIBXSMM_SYNC_NPAUSE); + } +# else + libxsmm_mutex_state lock_free = INTERNAL_SYNC_LOCK_FREE, lock_state = INTERNAL_SYNC_LOCK_LOCKED; + assert(0 != mutex); + while (0/*false*/ == LIBXSMM_ATOMIC_CMPSWP(&mutex->state, lock_free, lock_state, LIBXSMM_ATOMIC_RELAXED)) { + libxsmm_mutex_state state; + /* coverity[unreachable] may be reachable more than once due to volatile state */ + for (state = mutex->state; INTERNAL_SYNC_LOCK_FREE != state; state = mutex->state) { +# if defined(LIBXSMM_SYNC_FUTEX) && defined(__linux__) + LIBXSMM_SYNC_CYCLE_ELSE(&mutex->state, INTERNAL_SYNC_LOCK_FREE, LIBXSMM_SYNC_NPAUSE, { + /*const*/ libxsmm_mutex_state state_locked = INTERNAL_SYNC_LOCK_LOCKED; + if (INTERNAL_SYNC_LOCK_LOCKED != state || LIBXSMM_ATOMIC_CMPSWP(&mutex->state, + state_locked, INTERNAL_SYNC_LOCK_CONTESTED, LIBXSMM_ATOMIC_RELAXED)) + { + syscall(INTERNAL_SYNC_FUTEX, &mutex->state, FUTEX_WAIT, INTERNAL_SYNC_LOCK_CONTESTED, NULL, NULL, 0); + lock_state = INTERNAL_SYNC_LOCK_CONTESTED; + }} + ); + break; +# else + LIBXSMM_SYNC_CYCLE(&mutex->state, INTERNAL_SYNC_LOCK_FREE, LIBXSMM_SYNC_NPAUSE); +# endif + } + lock_free = INTERNAL_SYNC_LOCK_FREE; + } +# endif +#else + LIBXSMM_UNUSED(mutex); +#endif +} + + +LIBXSMM_API void libxsmm_mutex_release(libxsmm_mutex* mutex) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != mutex); + LIBXSMM_ATOMIC_SYNC(LIBXSMM_ATOMIC_SEQ_CST); +# if defined(LIBXSMM_SYNC_FUTEX) && defined(__linux__) && defined(__USE_GNU) + if (INTERNAL_SYNC_LOCK_CONTESTED == LIBXSMM_ATOMIC_FETCH_SUB(&mutex->state, 1, LIBXSMM_ATOMIC_RELAXED)) { + mutex->state = INTERNAL_SYNC_LOCK_FREE; + syscall(INTERNAL_SYNC_FUTEX, &mutex->state, FUTEX_WAKE, 1, NULL, NULL, 0); + } +# else + mutex->state = INTERNAL_SYNC_LOCK_FREE; +# endif +#else + LIBXSMM_UNUSED(mutex); +#endif +} + + +#if (0 != LIBXSMM_SYNC) +typedef LIBXSMM_CONCATENATE3(uint,LIBXSMM_SYNC_RWLOCK_BITS,_t) internal_sync_uint_t; +typedef LIBXSMM_CONCATENATE3(int,LIBXSMM_SYNC_RWLOCK_BITS,_t) internal_sync_int_t; +LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_sync_counter { + struct { internal_sync_uint_t writer, reader; } kind; + uint32_t bits; +} internal_sync_counter; +#endif +LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_rwlock { +#if (0 != LIBXSMM_SYNC) + volatile internal_sync_counter completions; + volatile internal_sync_counter requests; +#else + int dummy; +#endif +}; + + +LIBXSMM_API libxsmm_rwlock* libxsmm_rwlock_create(void) +{ + libxsmm_rwlock *const result = (libxsmm_rwlock*)malloc(sizeof(libxsmm_rwlock)); + if (0 != result) { +#if (0 != LIBXSMM_SYNC) + LIBXSMM_MEMZERO127(&result->completions); + LIBXSMM_MEMZERO127(&result->requests); +#else + LIBXSMM_MEMZERO127(result); +#endif + } + return result; +} + + +LIBXSMM_API void libxsmm_rwlock_destroy(const libxsmm_rwlock* rwlock) +{ + free((libxsmm_rwlock*)rwlock); +} + + +#if (0 != LIBXSMM_SYNC) +LIBXSMM_API_INLINE int internal_rwlock_trylock(libxsmm_rwlock* rwlock, internal_sync_counter* prev) +{ + internal_sync_counter next; + assert(0 != rwlock && 0 != prev); + do { + prev->bits = rwlock->requests.bits; + next.bits = prev->bits; + ++next.kind.writer; + } + while (0/*false*/ == LIBXSMM_ATOMIC_CMPSWP(&rwlock->requests.bits, prev->bits, next.bits, LIBXSMM_ATOMIC_RELAXED)); + return rwlock->completions.bits != prev->bits + ? (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK) + 1) /* not acquired */ + : (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK)); +} +#endif + + +LIBXSMM_API int libxsmm_rwlock_trylock(libxsmm_rwlock* rwlock) +{ +#if (0 != LIBXSMM_SYNC) + internal_sync_counter prev; + return internal_rwlock_trylock(rwlock, &prev); +#else + LIBXSMM_UNUSED(rwlock); + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK); +#endif +} + + +LIBXSMM_API void libxsmm_rwlock_acquire(libxsmm_rwlock* rwlock) +{ +#if (0 != LIBXSMM_SYNC) + internal_sync_counter prev; + if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK) != internal_rwlock_trylock(rwlock, &prev)) { + while (rwlock->completions.bits != prev.bits) { + LIBXSMM_SYNC_CYCLE(&rwlock->completions.bits, prev.bits, LIBXSMM_SYNC_NPAUSE); + } + } +#else + LIBXSMM_UNUSED(rwlock); +#endif +} + + +LIBXSMM_API void libxsmm_rwlock_release(libxsmm_rwlock* rwlock) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != rwlock); + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_FETCH_ADD, LIBXSMM_SYNC_RWLOCK_BITS)(&rwlock->completions.kind.writer, 1, LIBXSMM_ATOMIC_SEQ_CST); +#else + LIBXSMM_UNUSED(rwlock); +#endif +} + + +#if (0 != LIBXSMM_SYNC) +LIBXSMM_API_INLINE int internal_rwlock_tryread(libxsmm_rwlock* rwlock, internal_sync_counter* prev) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != rwlock && 0 != prev); + prev->bits = LIBXSMM_ATOMIC_FETCH_ADD(&rwlock->requests.bits, INTERNAL_SYNC_RWLOCK_READINC, LIBXSMM_ATOMIC_SEQ_CST); + return rwlock->completions.kind.writer != prev->kind.writer + ? (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK) + 1) /* not acquired */ + : (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK)); +#else + LIBXSMM_UNUSED(rwlock); LIBXSMM_UNUSED(prev); + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK); +#endif +} +#endif + + +LIBXSMM_API int libxsmm_rwlock_tryread(libxsmm_rwlock* rwlock) +{ +#if (0 != LIBXSMM_SYNC) + internal_sync_counter prev; + return internal_rwlock_tryread(rwlock, &prev); +#else + LIBXSMM_UNUSED(rwlock); + return LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK); +#endif +} + + +LIBXSMM_API void libxsmm_rwlock_acqread(libxsmm_rwlock* rwlock) +{ +#if (0 != LIBXSMM_SYNC) + internal_sync_counter prev; + if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_RWLOCK) != internal_rwlock_tryread(rwlock, &prev)) { + while (rwlock->completions.kind.writer != prev.kind.writer) { + LIBXSMM_SYNC_CYCLE(&rwlock->completions.kind.writer, prev.kind.writer, LIBXSMM_SYNC_NPAUSE); + } + } +#else + LIBXSMM_UNUSED(rwlock); +#endif +} + + +LIBXSMM_API void libxsmm_rwlock_relread(libxsmm_rwlock* rwlock) +{ +#if (0 != LIBXSMM_SYNC) + assert(0 != rwlock); + LIBXSMM_ATOMIC(LIBXSMM_ATOMIC_FETCH_ADD, LIBXSMM_SYNC_RWLOCK_BITS)(&rwlock->completions.kind.reader, 1, LIBXSMM_ATOMIC_SEQ_CST); +#else + LIBXSMM_UNUSED(rwlock); +#endif +} + + +LIBXSMM_API unsigned int libxsmm_get_pid(void) +{ +#if defined(_WIN32) + return (unsigned int)_getpid(); +#else + return (unsigned int)getpid(); +#endif +} + + +LIBXSMM_API_INTERN unsigned int internal_get_tid(void); +LIBXSMM_API_INTERN unsigned int internal_get_tid(void) +{ + const unsigned int nthreads = LIBXSMM_ATOMIC_ADD_FETCH(&libxsmm_thread_count, 1, LIBXSMM_ATOMIC_RELAXED); +#if !defined(NDEBUG) + static int error_once = 0; + if (LIBXSMM_NTHREADS_MAX < nthreads + && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: maximum number of threads is exhausted!\n"); + } +#endif + LIBXSMM_ASSERT(LIBXSMM_ISPOT(LIBXSMM_NTHREADS_MAX)); + return LIBXSMM_MOD2(nthreads - 1, LIBXSMM_NTHREADS_MAX); +} + + +LIBXSMM_API unsigned int libxsmm_get_tid(void) +{ +#if (0 != LIBXSMM_SYNC) +# if defined(LIBXSMM_SYNC_GENERIC_PID) + static LIBXSMM_TLS unsigned int tid = 0xFFFFFFFF; + if (0xFFFFFFFF == tid) tid = internal_get_tid(); + return tid; +# else + void* tls = LIBXSMM_TLS_GETVALUE(libxsmm_tlskey); + if (NULL == tls) { + static unsigned int tid[LIBXSMM_NTHREADS_MAX]; + const int i = internal_get_tid(); + tid[i] = i; tls = tid + i; + /* coverity[check_return] */ + LIBXSMM_TLS_SETVALUE(libxsmm_tlskey, tls); + } + return *(unsigned int*)tls; +# endif +#else + return 0; +#endif +} + diff --git a/third_party/libxsmm/src/libxsmm_timer.c b/third_party/libxsmm/src/libxsmm_timer.c new file mode 100644 index 0000000000000000000000000000000000000000..5a5c570517e7e1e12b186c3b59d02a87a85b8559 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_timer.c @@ -0,0 +1,221 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include +#include "libxsmm_main.h" + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if defined(_WIN32) +# include +#elif defined(__GNUC__) || defined(__PGI) || defined(_CRAYC) +# include +# include +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +#if defined(__powerpc64__) +# include +#endif + +#if !defined(LIBXSMM_TIMER_TSC) +# define LIBXSMM_TIMER_TSC +#endif +#if !defined(LIBXSMM_TIMER_WPC) +# define LIBXSMM_TIMER_WPC +#endif + +#if defined(LIBXSMM_TIMER_TSC) +# if defined(__powerpc64__) +# define LIBXSMM_TIMER_RDTSC(CYCLE) { \ + CYCLE = __ppc_get_timebase(); \ + } +# elif ((defined(LIBXSMM_PLATFORM_X86) && (64 <= (LIBXSMM_BITS))) && \ + (defined(__GNUC__) || defined(LIBXSMM_INTEL_COMPILER) || defined(__PGI))) +# define LIBXSMM_TIMER_RDTSC(CYCLE) { libxsmm_timer_tickint libxsmm_timer_rdtsc_hi_; \ + __asm__ __volatile__ ("rdtsc" : "=a"(CYCLE), "=d"(libxsmm_timer_rdtsc_hi_)); \ + CYCLE |= libxsmm_timer_rdtsc_hi_ << 32; \ + } +# elif (defined(_rdtsc) || defined(_WIN32)) +# define LIBXSMM_TIMER_RDTSC(CYCLE) (CYCLE = __rdtsc()) +# endif +#endif + + +LIBXSMM_API_INTERN double libxsmm_timer_duration_rtc(libxsmm_timer_tickint tick0, libxsmm_timer_tickint tick1) +{ + double result = (double)LIBXSMM_DELTA(tick0, tick1); +#if defined(_WIN32) +# if defined(LIBXSMM_TIMER_WPC) + LARGE_INTEGER frequency; + QueryPerformanceFrequency(&frequency); + result /= (double)frequency.QuadPart; +# else /* low resolution */ + result *= 1E-3; +# endif +#elif defined(CLOCK_MONOTONIC) + result *= 1E-9; +#else + result *= 1E-6; +#endif + return result; +} + + +LIBXSMM_API_INTERN libxsmm_timer_tickint libxsmm_timer_tick_rtc(void) +{ + libxsmm_timer_tickint result; +#if defined(_WIN32) +# if defined(LIBXSMM_TIMER_WPC) + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + result = (libxsmm_timer_tickint)t.QuadPart; +# else /* low resolution */ + result = (libxsmm_timer_tickint)GetTickCount64(); +# endif +#elif defined(CLOCK_MONOTONIC) + struct timespec t; + clock_gettime(CLOCK_MONOTONIC, &t); + result = 1000000000ULL * t.tv_sec + t.tv_nsec; +#else + struct timeval t; + gettimeofday(&t, 0); + result = 1000000ULL * t.tv_sec + t.tv_usec; +#endif + return result; +} + + +LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_GENERIC) +libxsmm_timer_tickint libxsmm_timer_tick_tsc(void) +{ + libxsmm_timer_tickint result; +#if defined(LIBXSMM_TIMER_RDTSC) + LIBXSMM_TIMER_RDTSC(result); +#else + result = libxsmm_timer_tick_rtc(); +#endif + return result; +} + + +LIBXSMM_API int libxsmm_get_timer_info(libxsmm_timer_info* info) +{ + int result; + if (NULL != info) { +#if defined(LIBXSMM_TIMER_RDTSC) + if (0 < libxsmm_timer_scale) { + info->tsc = 1; + } +# if !defined(LIBXSMM_INIT_COMPLETED) + else if (2 > libxsmm_ninit) { + libxsmm_init(); + if (0 < libxsmm_timer_scale) { + info->tsc = 1; + } + else { + info->tsc = 0; + } + } +# endif + else { + info->tsc = 0; + } +#else + info->tsc = 0; +#endif + result = EXIT_SUCCESS; + } + else { +#if !defined(NDEBUG) + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument for libxsmm_get_timer_info specified!\n"); + } +#endif + result = EXIT_FAILURE; + } + return result; +} + + +LIBXSMM_API libxsmm_timer_tickint libxsmm_timer_tick(void) +{ + libxsmm_timer_tickint result; +#if defined(LIBXSMM_TIMER_RDTSC) + if (0 < libxsmm_timer_scale) { + LIBXSMM_TIMER_RDTSC(result); + } +# if !defined(LIBXSMM_INIT_COMPLETED) + else if (2 > libxsmm_ninit) { + libxsmm_init(); + if (0 < libxsmm_timer_scale) { + LIBXSMM_TIMER_RDTSC(result); + } + else { + result = libxsmm_timer_tick_rtc(); + } + } +# endif + else { + result = libxsmm_timer_tick_rtc(); + } +#else + result = libxsmm_timer_tick_rtc(); +#endif + return result; +} + + +LIBXSMM_API double libxsmm_timer_duration(libxsmm_timer_tickint tick0, libxsmm_timer_tickint tick1) +{ + double result; +#if defined(LIBXSMM_TIMER_RDTSC) + if (0 < libxsmm_timer_scale) { + result = (double)LIBXSMM_DELTA(tick0, tick1) * libxsmm_timer_scale; + } + else +#endif + { + result = libxsmm_timer_duration_rtc(tick0, tick1); + } + return result; +} + + +#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_timer_ncycles)(libxsmm_timer_tickint* /*ncycles*/, const libxsmm_timer_tickint* /*tick0*/, const libxsmm_timer_tickint* /*tick1*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_timer_ncycles)(libxsmm_timer_tickint* ncycles, const libxsmm_timer_tickint* tick0, const libxsmm_timer_tickint* tick1) +{ +#if !defined(NDEBUG) + static int error_once = 0; + if (NULL != ncycles && NULL != tick0 && NULL != tick1) +#endif + { + *ncycles = libxsmm_timer_ncycles(*tick0, *tick1); + } +#if !defined(NDEBUG) + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_timer_ncycles specified!\n"); + } +#endif +} + +#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_trace.c b/third_party/libxsmm/src/libxsmm_trace.c new file mode 100644 index 0000000000000000000000000000000000000000..a0f41dcf889ce1a86a6c45cb11c9d9f3b784bcc8 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_trace.c @@ -0,0 +1,567 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_trace.h" +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_TRACE_MINDEPTH) || 0 > (LIBXSMM_TRACE_MINDEPTH) +# undef LIBXSMM_TRACE_MINDEPTH +# define LIBXSMM_TRACE_MINDEPTH 1 +#endif +#if !defined(LIBXSMM_TRACE_MAXDEPTH) || 0 >= (LIBXSMM_TRACE_MAXDEPTH) +# undef LIBXSMM_TRACE_MAXDEPTH +# define LIBXSMM_TRACE_MAXDEPTH 1024 +#endif +#if !defined(LIBXSMM_TRACE_SYMBOLSIZE) || 0 >= (LIBXSMM_TRACE_SYMBOLSIZE) +# undef LIBXSMM_TRACE_SYMBOLSIZE +# define LIBXSMM_TRACE_SYMBOLSIZE 256 +#endif +#if !defined(LIBXSMM_TRACE_DLINFO) && defined(__USE_GNU) +# define LIBXSMM_TRACE_DLINFO +#endif + +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) +#endif +#if !defined(NDEBUG) +# include +#endif +#if defined(_WIN32) || defined(__CYGWIN__) +# include +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4091) +# endif +# include +# if defined(_MSC_VER) +# pragma comment(lib, "dbghelp") +# endif +# if defined(_MSC_VER) +# pragma warning(pop) +# endif +LIBXSMM_APIVAR_DEFINE(volatile LONG internal_trace_initialized); +#else +LIBXSMM_APIVAR_DEFINE(volatile int internal_trace_initialized); +# include +# if defined(LIBXSMM_TRACE_DLINFO) +# include +# else +# include +# include +# include +# include +# include +# if (0 != LIBXSMM_SYNC) +LIBXSMM_APIVAR_DEFINE(LIBXSMM_TLS_TYPE internal_trace_key); +LIBXSMM_APIVAR_DEFINE(void* internal_trace_symbols[LIBXSMM_NTHREADS_MAX]); +# endif +LIBXSMM_API_INLINE void internal_delete(void* value) +{ + int fd; +# if !(defined(__APPLE__) && defined(__MACH__)) + LIBXSMM_ASSERT(NULL != value); +# endif + fd = *((int*)value); +# if defined(NDEBUG) + munmap(value, LIBXSMM_TRACE_SYMBOLSIZE); +# else /* library code is expected to be mute */ + if (0 != munmap(value, LIBXSMM_TRACE_SYMBOLSIZE)) { + const int error = errno; + fprintf(stderr, "LIBXSMM ERROR: %s (munmap error #%i at %p)\n", + strerror(error), error, value); + } +# endif + if (0 <= fd) { + close(fd); + } +# if !defined(NDEBUG) /* library code is expected to be mute */ + else { + fprintf(stderr, "LIBXSMM ERROR: invalid file descriptor (%i)\n", fd); + } +# endif +} +# if defined(__APPLE__) && defined(__MACH__) +/* taken from "libtransmission" fdlimit.c */ +LIBXSMM_API_INLINE int posix_fallocate(int fd, off_t offset, off_t length) +{ + fstore_t fst; + fst.fst_flags = F_ALLOCATECONTIG; + fst.fst_posmode = F_PEOFPOSMODE; + fst.fst_offset = offset; + fst.fst_length = length; + fst.fst_bytesalloc = 0; + return fcntl(fd, F_PREALLOCATE, &fst); +} +# elif (!defined(_XOPEN_SOURCE) || 600 > _XOPEN_SOURCE) && \ + (!defined(_POSIX_C_SOURCE) || 200112L > _POSIX_C_SOURCE) +/* C89: avoid warning about posix_fallocate declared implicitly */ +LIBXSMM_EXTERN int posix_fallocate(int, off_t, off_t); +# endif +# endif +LIBXSMM_EXTERN int mkstemp(char*) LIBXSMM_NOTHROW; +#endif +#if defined(LIBXSMM_OFFLOAD_TARGET) +# pragma offload_attribute(pop) +#endif + +LIBXSMM_APIVAR_DEFINE(int internal_trace_mindepth); +LIBXSMM_APIVAR_DEFINE(int internal_trace_threadid); +LIBXSMM_APIVAR_DEFINE(int internal_trace_maxnsyms); + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE int libxsmm_trace_init(int /*filter_threadid*/, int /*filter_mindepth*/, int /*filter_maxnsyms*/); +LIBXSMM_API int libxsmm_trace_init(int filter_threadid, int filter_mindepth, int filter_maxnsyms) +{ + int result = EXIT_SUCCESS; + if (0 == internal_trace_initialized) { + if (0 <= filter_threadid) ++filter_threadid; +#if defined(__TRACE) + { const char *const env = getenv("LIBXSMM_TRACE"); + if (NULL != env && 0 != *env) { + char buffer[32] = { 0 }; + if (1 == sscanf(env, "%32[^,],", buffer)) { + result = (0 <= sscanf(buffer, "%i", &filter_threadid) ? EXIT_SUCCESS : EXIT_FAILURE); + } + if (1 == sscanf(env, "%*[^,],%32[^,],", buffer)) { + result = (0 <= sscanf(buffer, "%i", &filter_mindepth) ? EXIT_SUCCESS : EXIT_FAILURE); + } + if (1 == sscanf(env, "%*[^,],%*[^,],%32s", buffer)) { + result = (0 <= sscanf(buffer, "%i", &filter_maxnsyms) ? EXIT_SUCCESS : EXIT_FAILURE); + } + else { + filter_maxnsyms = -1; /* all */ + } + if (EXIT_SUCCESS == result) { + internal_trace_initialized = -1; /* auto */ + } + } + } + if (EXIT_SUCCESS == result) +#endif + { +#if defined(LIBXSMM_TRACE) +# if defined(_WIN32) || defined(__CYGWIN__) + SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME); + result = (FALSE != SymInitialize(GetCurrentProcess(), NULL, TRUE) ? EXIT_SUCCESS : GetLastError()); +# elif (0 != LIBXSMM_SYNC) && !defined(LIBXSMM_TRACE_DLINFO) + result = LIBXSMM_TLS_CREATE(&internal_trace_key); +# endif + if (EXIT_SUCCESS == result) { + internal_trace_threadid = filter_threadid; + internal_trace_maxnsyms = filter_maxnsyms; + internal_trace_mindepth = filter_mindepth; + if (0 == internal_trace_initialized) { + internal_trace_initialized = 1; + } + } +#else + LIBXSMM_UNUSED(filter_threadid); + LIBXSMM_UNUSED(filter_mindepth); + LIBXSMM_UNUSED(filter_maxnsyms); +#endif + } + } + return result; +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE int libxsmm_trace_finalize(void); +LIBXSMM_API int libxsmm_trace_finalize(void) +{ + int result; +#if defined(LIBXSMM_TRACE) + result = EXIT_SUCCESS; + if (0 != internal_trace_initialized) { + internal_trace_initialized = 0; /* disable */ +# if defined(_WIN32) || defined(__CYGWIN__) + result = (FALSE != SymCleanup(GetCurrentProcess()) ? EXIT_SUCCESS : GetLastError()); +# elif (0 != LIBXSMM_SYNC) && !defined(LIBXSMM_TRACE_DLINFO) + result = LIBXSMM_TLS_DESTROY(internal_trace_key); + { int i = 0; + for (; i < LIBXSMM_NTHREADS_MAX; ++i) { + void *const buffer = internal_trace_symbols[i]; + if (NULL != buffer) internal_delete(buffer); + } + } +# endif + } +#else + result = EXIT_FAILURE; +#endif + return result; +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE unsigned int libxsmm_backtrace(const void* /*buffer*/[], unsigned int /*size*/, unsigned int /*skip*/); +LIBXSMM_API +#if defined(_WIN32) +/*TODO: no inline*/ +#elif defined(__GNUC__) +/*LIBXSMM_ATTRIBUTE(noinline)*/ +#endif +unsigned int libxsmm_backtrace(const void* buffer[], unsigned int size, unsigned int skip) +{ + unsigned int result; + if (NULL != buffer && 0 != size && skip < size) { + skip += LIBXSMM_TRACE_MINDEPTH; +#if defined(_WIN32) || defined(__CYGWIN__) + result = CaptureStackBackTrace(skip, LIBXSMM_MIN(size, LIBXSMM_TRACE_MAXDEPTH), (PVOID*)buffer, NULL/*hash*/); +#else + { const int n = backtrace((void**)buffer, LIBXSMM_MIN((int)(size + skip), LIBXSMM_TRACE_MAXDEPTH)); + if ((int)skip < n) { + result = n - skip; + if (0 != skip) { + memmove(buffer, buffer + skip, result * sizeof(void*)); + } + } + else { + result = 0; + } + } +#endif + } + else { + result = 0; + } + return result; +} + + +#if !defined(_WIN32) && !defined(__CYGWIN__) +LIBXSMM_API_INLINE const char* internal_trace_get_symbolname(const void* address, char* map, int fd, off_t fdoff) +{ + const char* result = NULL; +#if defined(LIBXSMM_TRACE_DLINFO) + Dl_info info; + LIBXSMM_UNUSED(fd); LIBXSMM_UNUSED(fdoff); + LIBXSMM_ASSERT(NULL != address && NULL != map); + if (0 != dladdr(address, &info) && NULL != info.dli_sname) { + strncpy(map, info.dli_sname, LIBXSMM_TRACE_SYMBOLSIZE - 1); + result = map; + } +#else + LIBXSMM_ASSERT(NULL != address && NULL != map); + backtrace_symbols_fd((void**)&address, 1, fd); + if (fdoff == lseek(fd, fdoff, SEEK_SET) /* reset map */ + && 1 == sscanf(map, "%*[^(](%s0x", map)) + { + char* c = map; + for (; '+' != *c && 0 != *c; ++c); + if ('+' == *c && c != map) { + result = map; + map = c; + } + } + *map = 0; /* terminate */ +#endif + return result; +} +#endif + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE +const char* libxsmm_trace_info(unsigned int* /*depth*/, unsigned int* /*threadid*/, const int* /*filter_threadid*/, + const void* /*filter_symbol*/, const int* /*filter_mindepth*/, const int* /*filter_maxnsyms*/); + +LIBXSMM_API +#if defined(_WIN32) +/*TODO: no inline*/ +#elif defined(__GNUC__) +/*LIBXSMM_ATTRIBUTE(noinline)*/ +#endif +const char* libxsmm_trace_info(unsigned int* depth, unsigned int* threadid, const int* filter_threadid, + const void* filter_symbol, const int* filter_mindepth, const int* filter_maxnsyms) +{ + const char *fname = NULL; +#if defined(LIBXSMM_TRACE) + static LIBXSMM_TLS int cerberus = 0; + /* check against entering a recursion (recursion should not happen due to + * attribute "no_instrument_function" but better prevent this in any case) + */ + if (0 == cerberus) { + int init; + ++cerberus; +# if defined(__GNUC__) && !defined(_CRAYC) + __asm__(""); +# endif + init = LIBXSMM_ATOMIC_LOAD(&internal_trace_initialized, LIBXSMM_ATOMIC_RELAXED); + if (0 != init) { /* do nothing if not yet initialized */ + const int mindepth = (NULL != filter_mindepth ? *filter_mindepth : internal_trace_mindepth); + const int maxnsyms = (NULL != filter_maxnsyms ? *filter_maxnsyms : internal_trace_maxnsyms); + const void *stacktrace[LIBXSMM_TRACE_MAXDEPTH]; + const int n = libxsmm_backtrace(stacktrace, LIBXSMM_TRACE_MAXDEPTH, 0); + int symbol = 0; + if (0 < n) { + const int filter = (NULL != filter_threadid ? *filter_threadid : internal_trace_threadid); + int abs_tid = 0; +# if defined(_WIN32) || defined(__CYGWIN__) || defined(LIBXSMM_TRACE_DLINFO) + static LIBXSMM_TLS struct { +# if defined(_WIN32) || defined(__CYGWIN__) + char buffer[sizeof(SYMBOL_INFO)+LIBXSMM_TRACE_SYMBOLSIZE]; +# else + char buffer[LIBXSMM_TRACE_SYMBOLSIZE]; +# endif + int tid; + } info; + if (0 != info.tid) { + abs_tid = LIBXSMM_ABS(info.tid); + } + else { + const int tid = LIBXSMM_ATOMIC_ADD_FETCH(&internal_trace_initialized, 0 < init ? 1 : -1, LIBXSMM_ATOMIC_RELAXED); + abs_tid = LIBXSMM_ABS(tid) - 1; + /* use sign bit to flag enabled fallback for symbol resolution */ + info.tid = -abs_tid; + } + LIBXSMM_ASSERT(0 < abs_tid); + if (0 > filter || filter == abs_tid) { + int next = symbol + 1; +# if defined(_WIN32) || defined(__CYGWIN__) + const HANDLE process = GetCurrentProcess(); + PSYMBOL_INFO value = (PSYMBOL_INFO)info.buffer; + value->SizeOfStruct = sizeof(SYMBOL_INFO); + value->MaxNameLen = LIBXSMM_TRACE_SYMBOLSIZE - 1; + value->NameLen = 0; +# endif + if (NULL != filter_symbol) { + struct { size_t d; int s; } approx = { (size_t)LIBXSMM_UNLIMITED, 0 }; + while (next < n && (filter_symbol == stacktrace[symbol] || +# if defined(_WIN32) || defined(__CYGWIN__) + (FALSE != SymFromAddr(process, (DWORD64)stacktrace[symbol], NULL, value) && 0 < value->NameLen))) + { + if (filter_symbol == stacktrace[symbol] || NULL != strstr(value->Name, (const char*)filter_symbol)) { +# else + (NULL != internal_trace_get_symbolname(stacktrace[symbol], info.buffer, 0, 0)))) + { + if (filter_symbol == stacktrace[symbol] || NULL != strstr(info.buffer, (const char*)filter_symbol)) { +# endif + symbol = next++; /* determine the symbol after the match which is checked below */ + break; + } + { const size_t d = LIBXSMM_DELTA((const char*)filter_symbol, (const char*)stacktrace[symbol]); + if (d < approx.d) { + approx.s = symbol + 1; + approx.d = d; + } + } + symbol = next++; + } + symbol = LIBXSMM_MAX((next != n ? symbol : approx.s/*not found*/) + mindepth/*shift*/, 0); + } + /* apply filters based on absolute symbol position */ + if ((NULL != filter_symbol || LIBXSMM_MAX(mindepth, 0) <= symbol) && (0 >= maxnsyms || symbol < maxnsyms)) { + if (symbol != next && symbol < n && filter_symbol != stacktrace[symbol] && +# if defined(_WIN32) || defined(__CYGWIN__) + FALSE != SymFromAddr(process, (DWORD64)stacktrace[symbol], NULL, value) && 0 < value->NameLen) +# else + NULL != internal_trace_get_symbolname(stacktrace[symbol], info.buffer, 0, 0)) +# endif + { + /* disable fallback allowing unresolved symbol names */ + info.tid = abs_tid; /* make unsigned */ +# if defined(_WIN32) || defined(__CYGWIN__) + fname = value->Name; +# else + fname = info.buffer; +# endif + } + if (NULL == fname && 0 > info.tid) { /* fallback allowing unresolved symbol names */ +# if defined(__MINGW32__) + sprintf(info.buffer, "%p", stacktrace[symbol]); +# else + sprintf(info.buffer, "0x%" PRIxPTR, (uintptr_t)stacktrace[symbol]); +# endif + fname = info.buffer; + } + } + } +# else +# if (0 == LIBXSMM_SYNC) + static char raw_c; + char */*const*/ raw_value = &raw_c; /* const: avoid warning (below / constant control-flow) */ +# else + char *const raw_value = (char*)LIBXSMM_TLS_GETVALUE(internal_trace_key); +# endif + const off_t fdoff = sizeof(int) * 2; + int* ivalue = NULL, fd = -1; + char* value = NULL; + if (NULL != raw_value) { + ivalue = (int*)raw_value; + abs_tid = (0 <= ivalue[1] ? ivalue[1] : -ivalue[1]); + if (0 > filter || filter == abs_tid) { + fd = ivalue[0]; + if (0 <= fd && fdoff == lseek(fd, fdoff, SEEK_SET)) { + value = raw_value + fdoff; + } +# if !defined(NDEBUG) /* library code is expected to be mute */ + else { + fprintf(stderr, "LIBXSMM ERROR: failed to get buffer\n"); + } +# endif + } + } + else { + char filename[] = "/tmp/.libxsmm_map." LIBXSMM_MKTEMP_PATTERN; + /* coverity[secure_temp] */ + fd = mkstemp(filename); + if (0 <= fd) { + if (0 == unlink(filename) && 0 == posix_fallocate(fd, 0, LIBXSMM_TRACE_SYMBOLSIZE)) { + char *const buffer = (char*)mmap(NULL, LIBXSMM_TRACE_SYMBOLSIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (MAP_FAILED != buffer) { + int check = -1; + ivalue = (int*)buffer; + ivalue[0] = fd; /* valid file descriptor for internal_delete */ + if ( +# if (0 != LIBXSMM_SYNC) + 0 == LIBXSMM_TLS_SETVALUE(internal_trace_key, buffer) && +# endif + (sizeof(int) * 1) == read(fd, &check, sizeof(int)) && + fdoff == lseek(fd, sizeof(int), SEEK_CUR) && + check == fd) + { + const int tid = LIBXSMM_ATOMIC_ADD_FETCH(&internal_trace_initialized, 0 < init ? 1 : -1, LIBXSMM_ATOMIC_RELAXED); + abs_tid = LIBXSMM_ABS(tid) - 1; + LIBXSMM_ASSERT(0 < abs_tid); +# if (0 != LIBXSMM_SYNC) + LIBXSMM_ASSERT(abs_tid < LIBXSMM_NTHREADS_MAX); + internal_trace_symbols[abs_tid] = buffer; +# endif + /* use sign bit to flag enabled fallback for symbol resolution */ + ivalue[1] = -abs_tid; + if (0 > filter || (abs_tid - 1) == filter) { + value = buffer + fdoff; + } + } + else { +# if !defined(NDEBUG) /* library code is expected to be mute */ + fprintf(stderr, "LIBXSMM ERROR: failed to setup buffer\n"); +# endif + internal_delete(buffer); + } + } +# if !defined(NDEBUG) + else { + const int error = errno; + fprintf(stderr, "LIBXSMM ERROR: %s (mmap allocation error #%i)\n", + strerror(error), error); + } +# endif + } +# if !defined(NDEBUG) /* library code is expected to be mute */ + else { + fprintf(stderr, "LIBXSMM ERROR: failed to setup file descriptor (%i)\n", fd); + } +# endif + } + } + if (NULL != value) { + int next = symbol + 1; + if (NULL != filter_symbol) { + struct { size_t d; int s; } approx = { (size_t)LIBXSMM_UNLIMITED, 0 }; + while (next < n && (filter_symbol == stacktrace[symbol] || + NULL != internal_trace_get_symbolname(stacktrace[symbol], value, fd, fdoff))) + { + if (filter_symbol == stacktrace[symbol] || NULL != strstr(value, (const char*)filter_symbol)) { + symbol = next++; /* determine the symbol after the match which is checked below */ + break; + } + { const size_t d = LIBXSMM_DELTA((const char*)filter_symbol, (const char*)stacktrace[symbol]); + if (d < approx.d) { + approx.s = symbol + 1; + approx.d = d; + } + } + symbol = next++; + } + symbol = LIBXSMM_MAX((next != n ? symbol : approx.s/*not found*/) + mindepth/*shift*/, 0); + } + /* apply filters based on absolute symbol position */ + if ((NULL != filter_symbol || LIBXSMM_MAX(mindepth, 0) <= symbol) && (0 >= maxnsyms || symbol < maxnsyms)) { + if (symbol != next && symbol < n && filter_symbol != stacktrace[symbol] && + NULL != internal_trace_get_symbolname(stacktrace[symbol], value, fd, fdoff)) + { + /* disable fallback allowing unresolved symbol names */ + ivalue[1] = abs_tid; /* make unsigned */ + fname = value; + } + if (NULL == fname && 0 > ivalue[1]) { /* fallback to symbol address */ + sprintf(value, "0x%llx", (unsigned long long)stacktrace[symbol]); + fname = value; + } + } + } +# endif + if (threadid) *threadid = abs_tid - 1; + if (depth) *depth = symbol; + } + } + --cerberus; + } +#else + LIBXSMM_UNUSED(depth); + LIBXSMM_UNUSED(threadid); + LIBXSMM_UNUSED(filter_threadid); + LIBXSMM_UNUSED(filter_symbol); + LIBXSMM_UNUSED(filter_mindepth); + LIBXSMM_UNUSED(filter_maxnsyms); +#endif + return fname; +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE +void libxsmm_trace(FILE* stream, const int* /*filter_threadid*/, const void* /*filter_symbol*/, const int* /*filter_mindepth*/, const int* /*filter_maxnsyms*/); + +LIBXSMM_API void libxsmm_trace(FILE* stream, const int* filter_threadid, const void* filter_symbol, const int* filter_mindepth, const int* filter_maxnsyms) +{ +#if defined(LIBXSMM_TRACE) + unsigned int depth, threadid; + const char *const name = libxsmm_trace_info(&depth, &threadid, filter_threadid, filter_symbol, filter_mindepth, filter_maxnsyms); + if (NULL != name && 0 != *name) { /* implies actual other results to be valid */ + LIBXSMM_ASSERT(NULL != stream/*otherwise fprintf handles the error*/); + if ((NULL == filter_threadid && 0 > internal_trace_threadid) || (NULL != filter_threadid && 0 > *filter_threadid)) { + fprintf(stream, "%*s%s@%u\n", (int)depth, "", name, threadid); + } + else { + fprintf(stream, "%*s%s\n", (int)depth, "", name); + } + } +#else /* suppress warning */ + LIBXSMM_UNUSED(stream); + LIBXSMM_UNUSED(filter_threadid); + LIBXSMM_UNUSED(filter_symbol); + LIBXSMM_UNUSED(filter_mindepth); + LIBXSMM_UNUSED(filter_maxnsyms); +#endif +} + + +#if defined(__TRACE) && defined(__GNUC__) && defined(LIBXSMM_BUILD) + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE void __cyg_profile_func_enter(void* /*this_fn*/, void* /*call_site*/); +LIBXSMM_API void __cyg_profile_func_enter(void* this_fn, void* call_site) +{ +#if defined(LIBXSMM_TRACE) + if (0 > internal_trace_initialized) { + /* NULL: inherit global settings from libxsmm_trace_init */ + libxsmm_trace(stderr, NULL/*filter_threadid*/, "__cyg_profile_func_enter"/*LIBXSMM_FUNCNAME*/, NULL, NULL); + } +#endif + LIBXSMM_UNUSED(this_fn); LIBXSMM_UNUSED(call_site); +} + + +LIBXSMM_API LIBXSMM_ATTRIBUTE_NO_TRACE void __cyg_profile_func_exit(void* /*this_fn*/, void* /*call_site*/); +LIBXSMM_API void __cyg_profile_func_exit(void* this_fn, void* call_site) +{ + LIBXSMM_UNUSED(this_fn); LIBXSMM_UNUSED(call_site); /* suppress warning */ +} + +#endif /*defined(__TRACE) && defined(__GNUC__) && defined(LIBXSMM_BUILD)*/ + diff --git a/third_party/libxsmm/src/libxsmm_trace.h b/third_party/libxsmm/src/libxsmm_trace.h new file mode 100644 index 0000000000000000000000000000000000000000..3a6772b2181a615f4ad151774542197e1b4a3efc --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_trace.h @@ -0,0 +1,124 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_TRACE_H +#define LIBXSMM_TRACE_H + +#include + +#if (defined(__TRACE) || defined(LIBXSMM_BUILD) || !defined(_WIN32)) +# define LIBXSMM_TRACE +#endif +#if !defined(LIBXSMM_TRACE_CALLERID_MAXDEPTH) +# define LIBXSMM_TRACE_CALLERID_MAXDEPTH 8 +#endif +#if !defined(LIBXSMM_TRACE_CALLERID_GCCBUILTIN) && \ + ((!defined(_WIN32) || defined(__MINGW32__) || (defined(_MSC_VER) && defined(__clang__))) && \ + (!defined(__PGI) || LIBXSMM_VERSION2(19, 0) <= LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__)) && \ + (defined(__GNUC__) || defined(__clang__))) +# define LIBXSMM_TRACE_CALLERID_GCCBUILTIN +#endif + + +/** Initializes the trace facility; NOT thread-safe. */ +LIBXSMM_API int libxsmm_trace_init( + /* Filter for thread id (-1: all). */ + int filter_threadid, + /* Specify min. depth of stack trace (0: all). */ + int filter_mindepth, + /* Specify max. depth of stack trace (-1: all). */ + int filter_maxnsyms); + +/** Finalizes the trace facility; NOT thread-safe. */ +LIBXSMM_API int libxsmm_trace_finalize(void); + +/** Receives the backtrace of up to 'size' addresses. Returns the actual number of addresses (n <= size). */ +LIBXSMM_API unsigned int libxsmm_backtrace(const void* buffer[], unsigned int size, unsigned int skip); + +#if defined(LIBXSMM_TRACE_CALLERID_GCCBUILTIN) && !defined(__INTEL_COMPILER) +# if defined(__clang__) +# pragma clang diagnostic push +# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) +# pragma GCC diagnostic push +# endif +# if defined(__clang__) +# pragma clang diagnostic ignored "-Wunknown-warning-option" +# if LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) +# pragma clang diagnostic ignored "-Wframe-address" +# endif +# elif defined(__GNUC__) /* no version-check */ +# pragma GCC diagnostic ignored "-Wpragmas" +# pragma GCC diagnostic ignored "-Wframe-address" +# endif +#endif +LIBXSMM_API_INLINE const void* libxsmm_trace_caller_id(unsigned int level) { /* must be inline */ +#if defined(LIBXSMM_TRACE_CALLERID_GCCBUILTIN) + switch (level) { +# if 0 + case 0: return __builtin_extract_return_addr(__builtin_return_address(0)); + case 1: return __builtin_extract_return_addr(__builtin_return_address(1)); + case 2: return __builtin_extract_return_addr(__builtin_return_address(2)); + case 3: return __builtin_extract_return_addr(__builtin_return_address(3)); +# else + case 0: return __builtin_frame_address(1); + case 1: return __builtin_frame_address(2); + case 2: return __builtin_frame_address(3); + case 3: return __builtin_frame_address(4); +# endif + default: +#else + { +# if defined(_WIN32) + if (0 == level) return _AddressOfReturnAddress(); + else +# endif +#endif + { const void* stacktrace[LIBXSMM_TRACE_CALLERID_MAXDEPTH]; + const unsigned int n = libxsmm_backtrace(stacktrace, LIBXSMM_TRACE_CALLERID_MAXDEPTH, 0/*skip*/); + return (level < n ? stacktrace[level] : NULL); + } + } +} +#if defined(LIBXSMM_TRACE_CALLERID_GCCBUILTIN) && !defined(__INTEL_COMPILER) +# if defined(__clang__) +# pragma clang diagnostic pop +# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__) +# pragma GCC diagnostic pop +# endif +#endif + +/** Returns the name of the function where libxsmm_trace is called from; thread-safe. */ +LIBXSMM_API const char* libxsmm_trace_info( + /* Query and output the abs. location in stacktrace (no input). */ + unsigned int* depth, + /* Query and output the thread id (no input). */ + unsigned int* threadid, + /* Filter for thread id (-1: all, NULL: libxsmm_trace_init). */ + const int* filter_threadid, + /* Lookup symbol (depth argument becomes relative to symbol position). */ + const void* filter_symbol, + /* Specify min. abs. position in stack trace (-1 or 0: all, NULL: libxsmm_trace_init). */ + const int* filter_mindepth, + /* Specify max. depth of stack trace (-1 or 0: all, NULL: libxsmm_trace_init). */ + const int* filter_maxnsyms); + +/** Prints an entry of the function where libxsmm_trace is called from (indented/hierarchical). */ +LIBXSMM_API void libxsmm_trace(FILE* stream, + /* Filter for thread id (-1: all, NULL: libxsmm_trace_init). */ + const int* filter_threadid, + /* Lookup symbol (depth argument becomes relative to symbol position). */ + const void* filter_symbol, + /* Specify min. absolute pos. in stack trace (-1 or 0: all, NULL: libxsmm_trace_init). */ + const int* filter_mindepth, + /* Specify max. depth of stack trace (-1 or 0: all, NULL: libxsmm_trace_init). */ + const int* filter_maxnsyms); + +#endif /*LIBXSMM_TRACE_H*/ + diff --git a/third_party/libxsmm/src/libxsmm_xcopy.c b/third_party/libxsmm/src/libxsmm_xcopy.c new file mode 100644 index 0000000000000000000000000000000000000000..4dfc679c328e440e5c38af13300e7f26addbdb58 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_xcopy.c @@ -0,0 +1,735 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#include "libxsmm_xcopy.h" + +#if !defined(LIBXSMM_MCOPY_JIT_TINY) && 0 +# define LIBXSMM_MCOPY_JIT_TINY +#endif + + +/* definition of corresponding variables */ +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT)) +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_xcopy_jit); +#endif +LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_xcopy_taskscale); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_mcopy_mbytes); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_mzero_mbytes); +LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_tcopy_mbytes); +LIBXSMM_APIVAR_PUBLIC_DEF(float libxsmm_mcopy_nscale); +LIBXSMM_APIVAR_PUBLIC_DEF(float libxsmm_mzero_nscale); +LIBXSMM_APIVAR_PUBLIC_DEF(float libxsmm_tcopy_nscale); + + +LIBXSMM_API_INTERN void libxsmm_xcopy_init(int archid) +{ + { /* setup tile sizes according to CPUID or environment */ + if (LIBXSMM_X86_AVX512_CORE <= archid) { /* avx-512/core */ + libxsmm_mcopy_mbytes = 0; + libxsmm_mcopy_nscale = 0.f; + libxsmm_mzero_mbytes = 0; + libxsmm_mzero_nscale = 0.f; + libxsmm_tcopy_mbytes = 32768; + libxsmm_tcopy_nscale = 0.f; + } + else if (LIBXSMM_X86_AVX512_MIC <= archid && LIBXSMM_X86_AVX512_CORE > archid) { + libxsmm_mcopy_mbytes = 0; + libxsmm_mcopy_nscale = 0.f; + libxsmm_mzero_mbytes = 0; + libxsmm_mzero_nscale = 0.f; + libxsmm_tcopy_mbytes = 32768; + libxsmm_tcopy_nscale = 0.f; + } + else { /* avx2 */ + libxsmm_mcopy_mbytes = 0; + libxsmm_mcopy_nscale = 0.f; + libxsmm_mzero_mbytes = 8192; + libxsmm_mzero_nscale = 0.f; + libxsmm_tcopy_mbytes = 4096; + libxsmm_tcopy_nscale = 0.f; + } + } + { /* mcopy: load/adjust tile sizes (measured as if DP) */ + const char *const env_m = getenv("LIBXSMM_MCOPY_M"), *const env_n = getenv("LIBXSMM_MCOPY_N"); + const int m = ((NULL == env_m || 0 == *env_m) ? 0 : atoi(env_m)); + const int n = ((NULL == env_n || 0 == *env_n) ? 0 : atoi(env_n)); + if (0 < m) libxsmm_mcopy_mbytes = LIBXSMM_MAX(m, 1) * 8/*DP*/; + if (0 != libxsmm_mcopy_mbytes && 0 != libxsmm_mcopy_nscale) { + if (0 < n) libxsmm_mcopy_nscale = ((float)(n * 8/*DP*/)) / libxsmm_mcopy_mbytes; + if (1 > (libxsmm_mcopy_nscale * libxsmm_mcopy_mbytes)) { + const float stretch = 1.f / libxsmm_mcopy_mbytes; + libxsmm_mcopy_nscale = LIBXSMM_MAX(stretch, libxsmm_mcopy_nscale); + } + } + } + { /* mzero: load/adjust tile sizes (measured as if DP) */ + const char *const env_m = getenv("LIBXSMM_MZERO_M"), *const env_n = getenv("LIBXSMM_MZERO_N"); + const int m = ((NULL == env_m || 0 == *env_m) ? 0 : atoi(env_m)); + const int n = ((NULL == env_n || 0 == *env_n) ? 0 : atoi(env_n)); + if (0 < m) libxsmm_mzero_mbytes = LIBXSMM_MAX(m, 1) * 8/*DP*/; + if (0 != libxsmm_mzero_mbytes && 0 != libxsmm_mzero_nscale) { + if (0 < n) libxsmm_mzero_nscale = ((float)(n * 8/*DP*/)) / libxsmm_mzero_mbytes; + if (1 > (libxsmm_mzero_nscale * libxsmm_mzero_mbytes)) { + const float stretch = 1.f / libxsmm_mzero_mbytes; + libxsmm_mzero_nscale = LIBXSMM_MAX(stretch, libxsmm_mzero_nscale); + } + } + } + { /* tcopy: load/adjust tile sizes (measured as if DP) */ + const char *const env_m = getenv("LIBXSMM_TCOPY_M"), *const env_n = getenv("LIBXSMM_TCOPY_N"); + const int m = ((NULL == env_m || 0 == *env_m) ? 0 : atoi(env_m)); + const int n = ((NULL == env_n || 0 == *env_n) ? 0 : atoi(env_n)); + if (0 < m) libxsmm_tcopy_mbytes = LIBXSMM_MAX(m, 1) * 8/*DP*/; + if (0 != libxsmm_tcopy_mbytes && 0 != libxsmm_tcopy_nscale) { + if (0 < n) libxsmm_tcopy_nscale = ((float)(n * 8/*DP*/)) / libxsmm_tcopy_mbytes; + if (1 > (libxsmm_tcopy_nscale * libxsmm_tcopy_mbytes)) { + const float stretch = 1.f / libxsmm_tcopy_mbytes; + libxsmm_tcopy_nscale = LIBXSMM_MAX(stretch, libxsmm_tcopy_nscale); + } + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT)) && defined(LIBXSMM_PLATFORM_X86) + /* check if JIT-code generation is permitted */ + if (LIBXSMM_X86_AVX2 <= libxsmm_target_archid && LIBXSMM_X86_ALLFEAT >= libxsmm_target_archid) { + const char *const env_jit = getenv("LIBXSMM_XCOPY_JIT"); + libxsmm_xcopy_jit = ((NULL == env_jit || 0 == *env_jit) ? (LIBXSMM_XCOPY_JIT) : atoi(env_jit)); + } +#endif + { /* determines if OpenMP tasks are used (when available) */ + const char *const env_t = getenv("LIBXSMM_XCOPY_TASKS"); + libxsmm_xcopy_taskscale = ((NULL == env_t || 0 == *env_t) + ? 0/*disabled*/ : (LIBXSMM_XCOPY_TASKSCALE * atoi(env_t))); + } +} + + +LIBXSMM_API_INTERN void libxsmm_xcopy_finalize(void) +{ +} + + +LIBXSMM_API void libxsmm_matcopy_task_internal(void* out, const void* in, unsigned int typesize, + unsigned int m, unsigned int n, unsigned int ldi, unsigned int ldo, + unsigned int km, unsigned int kn, libxsmm_xcopykernel kernel, + int tid, int ntasks) +{ + const unsigned int tm = (0 == km ? m : km); + const unsigned int tn = (0 == kn ? LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n) : kn); + const int mtasks = LIBXSMM_UPDIV(m, tm); + unsigned int m0, m1, n0, n1; + + LIBXSMM_ASSERT_MSG(tid < ntasks && 0 < ntasks, "Invalid task setup"); + LIBXSMM_ASSERT_MSG(tm <= m && tn <= n, "Invalid problem size"); + LIBXSMM_ASSERT_MSG(0 < tm && 0 < tn, "Invalid tile size"); + LIBXSMM_ASSERT_MSG(typesize <= 255, "Invalid type-size"); + LIBXSMM_ASSERT(0 < mtasks); + + if (ntasks <= mtasks) { /* parallelized over M */ + const unsigned int mt = LIBXSMM_UPDIV(m, ntasks); + m0 = LIBXSMM_MIN(tid * mt, m); + m1 = LIBXSMM_MIN(m0 + mt, m); + n0 = 0; n1 = n; + } + else { /* parallelized over M and N */ + const int mntasks = ntasks / mtasks; + const int mtid = tid / mntasks, ntid = tid - mtid * mntasks; + const unsigned int nt = LIBXSMM_UP(LIBXSMM_UPDIV(n, mntasks), tn) ; + m0 = LIBXSMM_MIN(mtid * tm, m); m1 = LIBXSMM_MIN(m0 + tm, m); + n0 = LIBXSMM_MIN(ntid * nt, n); n1 = LIBXSMM_MIN(n0 + nt, n); + } + + LIBXSMM_ASSERT_MSG(m0 <= m1 && m1 <= m, "Invalid task size"); + LIBXSMM_ASSERT_MSG(n0 <= n1 && n1 <= n, "Invalid task size"); + + if (NULL != in) { /* copy-kernel */ + libxsmm_matcopy_internal(out, in, typesize, ldi, ldo, + m0, m1, n0, n1, tm, tn, kernel); + } + else { + libxsmm_matzero_internal(out, typesize, ldo, + m0, m1, n0, n1, tm, tn, kernel); + } +} + + +LIBXSMM_API void libxsmm_otrans_task_internal(void* out, const void* in, unsigned int typesize, + unsigned int m, unsigned int n, unsigned int ldi, unsigned int ldo, + unsigned int km, unsigned int kn, libxsmm_xcopykernel kernel, + int tid, int ntasks) +{ + const unsigned int tm = (0 == km ? m : km); + const unsigned int tn = (0 == kn ? LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n) : kn); + const int mtasks = LIBXSMM_UPDIV(m, tm); + unsigned int m0, m1, n0, n1; + + LIBXSMM_ASSERT_MSG(tid < ntasks && 0 < ntasks, "Invalid task setup"); + LIBXSMM_ASSERT_MSG(tm <= m && tn <= n, "Invalid problem size"); + LIBXSMM_ASSERT_MSG(0 < tm && 0 < tn, "Invalid tile size"); + LIBXSMM_ASSERT_MSG(typesize <= 255, "Invalid type-size"); + LIBXSMM_ASSERT(0 < mtasks); + + if (ntasks <= mtasks) { /* parallelized over M */ + const unsigned int mt = LIBXSMM_UPDIV(m, ntasks); + m0 = LIBXSMM_MIN(tid * mt, m); + m1 = LIBXSMM_MIN(m0 + mt, m); + n0 = 0; n1 = n; + } + else { /* parallelized over M and N */ + const int mntasks = ntasks / mtasks; + const int mtid = tid / mntasks, ntid = tid - mtid * mntasks; + const unsigned int nt = LIBXSMM_UP(LIBXSMM_UPDIV(n, mntasks), tn); + m0 = LIBXSMM_MIN(mtid * tm, m); m1 = LIBXSMM_MIN(m0 + tm, m); + n0 = LIBXSMM_MIN(ntid * nt, n); n1 = LIBXSMM_MIN(n0 + nt, n); + } + + LIBXSMM_ASSERT_MSG(m0 <= m1 && m1 <= m, "Invalid task size"); + LIBXSMM_ASSERT_MSG(n0 <= n1 && n1 <= n, "Invalid task size"); + + libxsmm_otrans_internal(out, in, typesize, ldi, ldo, m0, m1, n0, n1, tm, tn, kernel); +} + + +LIBXSMM_API_INTERN void libxsmm_matcopy_internal(void* out, const void* in, + unsigned int typesize, unsigned int ldi, unsigned int ldo, + unsigned int m0, unsigned int m1, unsigned int n0, unsigned int n1, + unsigned int tm, unsigned int tn, libxsmm_xcopykernel kernel) +{ + LIBXSMM_ASSERT(NULL != in); + LIBXSMM_XCOPY(LIBXSMM_MCOPY_KERNEL, LIBXSMM_MCOPY_CALL, kernel, + out, in, typesize, ldi, ldo, tm, tn, m0, m1, n0, n1); +} + + +LIBXSMM_API_INTERN void libxsmm_matzero_internal(void* out, unsigned int typesize, unsigned int ldo, + unsigned int m0, unsigned int m1, unsigned int n0, unsigned int n1, + unsigned int tm, unsigned int tn, libxsmm_xcopykernel kernel) +{ + /* coverity[ptr_arith] */ + LIBXSMM_XCOPY(LIBXSMM_MZERO_KERNEL, LIBXSMM_MZERO_CALL, kernel, + out, NULL, typesize, 0, ldo, tm, tn, m0, m1, n0, n1); +} + + +LIBXSMM_API_INTERN void libxsmm_otrans_internal(void* out, const void* in, + unsigned int typesize, unsigned int ldi, unsigned int ldo, + unsigned int m0, unsigned int m1, unsigned int n0, unsigned int n1, + unsigned int tm, unsigned int tn, libxsmm_xcopykernel kernel) +{ + LIBXSMM_ASSERT(NULL != in); + LIBXSMM_XCOPY(LIBXSMM_TCOPY_KERNEL, LIBXSMM_TCOPY_CALL, kernel, + out, in, typesize, ldi, ldo, tm, tn, m0, m1, n0, n1); +} + + +LIBXSMM_API void libxsmm_matcopy_task(void* out, const void* in, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, + int tid, int ntasks) +{ + LIBXSMM_INIT + if (0 < typesize && 256 > typesize && m <= ldi && m <= ldo && out != in && + ((NULL != out && 0 < m && 0 < n) || (0 == m && 0 == n)) && + /* use (signed) integer types, but check sanity of input */ + 0 <= tid && tid < ntasks) + { + if (0 < m && 0 < n) { + unsigned int tm, tn, ts; + libxsmm_xcopykernel kernel; + kernel.ptr = NULL; + if (NULL != in) { /* mcopy */ + tm = LIBXSMM_UPDIV(libxsmm_mcopy_mbytes, typesize); + tn = (unsigned int)(libxsmm_mcopy_nscale * tm); + ts = libxsmm_mcopy_mbytes; + } + else { /* mzero */ + tm = LIBXSMM_UPDIV(libxsmm_mzero_mbytes, typesize); + tn = (unsigned int)(libxsmm_mzero_nscale * tm); + ts = libxsmm_mzero_mbytes; + } + if (0 == tm) tm = m; + if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n); + if (0 != ts && ts < (tm * tn * typesize)) { + tm = LIBXSMM_MAX(ts / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN); + } + if ((unsigned int)m < tm || (unsigned int)n < tn) { + if (1 == ntasks) { + tm = (unsigned int)m; tn = (unsigned int)n; + } + else { + const unsigned int tasksize = (((unsigned int)m) * (unsigned int)n) / ((unsigned int)(ntasks * libxsmm_mcopy_nscale)); + const unsigned int nn = libxsmm_isqrt_u32(tasksize); + const unsigned int mm = (unsigned int)(libxsmm_mcopy_nscale * nn); + tn = LIBXSMM_CLMP((unsigned int)n, 1, nn); + tm = LIBXSMM_CLMP((unsigned int)m, 1, mm); + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 2)) +# if !defined(LIBXSMM_MCOPY_JIT_TINY) + else +# endif + if (0 != (2 & libxsmm_xcopy_jit)) { /* JIT'ted matrix-copy permitted? */ + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_MELTW_FLAG_UNARY_NONE, + NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/); + break; + } + } +#endif + libxsmm_matcopy_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, tid, ntasks); + } + } + else { + static int error_once = 0; + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (0 > tid || tid >= ntasks) { + fprintf(stderr, "LIBXSMM ERROR: the matrix-copy thread-id or number of tasks is incorrect!\n"); + } + else if (NULL == out) { + fprintf(stderr, "LIBXSMM ERROR: the matrix-copy input and/or output is NULL!\n"); + } + else if (out == in) { + fprintf(stderr, "LIBXSMM ERROR: output and input of the matrix-copy must be different!\n"); + } + else if (0 == typesize || 256 <= typesize) { + fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-copy specified!\n"); + } + else if (ldi < m || ldo < m) { + fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the matrix-copy is/are too small!\n"); + } + else if (0 > m || 0 > n) { + fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the matrix-copy is/are negative!\n"); + } + } + } +} + + +LIBXSMM_API void libxsmm_matcopy(void* out, const void* in, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo) +{ + libxsmm_matcopy_task(out, in, typesize, m, n, ldi, ldo, 0/*tid*/, 1/*ntasks*/); +} + + +LIBXSMM_API void libxsmm_otrans_task(void* out, const void* in, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, + int tid, int ntasks) +{ + static int error_once = 0; + LIBXSMM_INIT + if (0 < typesize && 256 > typesize && m <= ldi && n <= ldo && + ((NULL != out && NULL != in && 0 < m && 0 < n) || (0 == m && 0 == n)) && + /* use (signed) integer types, but check sanity of input */ + 0 <= tid && tid < ntasks) + { + if (0 < m && 0 < n) { + if (out != in) { + unsigned int tm = LIBXSMM_UPDIV(libxsmm_tcopy_mbytes, typesize); + unsigned int tn = (unsigned int)(libxsmm_tcopy_nscale * tm); + libxsmm_xcopykernel kernel; + kernel.ptr = NULL; + if (0 == tm) tm = m; + if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n); + if (0 != libxsmm_tcopy_mbytes && libxsmm_tcopy_mbytes < (tm * tn * typesize)) { + tm = LIBXSMM_MAX(libxsmm_tcopy_mbytes / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN); + } + if ((unsigned int)m < tm || (unsigned int)n < tn) { + if (1 == ntasks) { +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + if (0 != (1 & libxsmm_xcopy_jit)) { /* JIT'ted transpose permitted? */ + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + } + if (NULL != kernel.ptr) { /* JIT-kernel available */ + LIBXSMM_TCOPY_CALL(kernel, typesize, in, ldi, out, ldo); + return; /* fast path */ + } + } +#endif + tm = (unsigned int)m; tn = (unsigned int)n; + } + else { + const unsigned int tasksize = (((unsigned int)m) * (unsigned int)n) / ((unsigned int)(ntasks * libxsmm_tcopy_nscale)); + const unsigned int nn = libxsmm_isqrt_u32(tasksize); + const unsigned int mm = (unsigned int)(libxsmm_tcopy_nscale * nn); + tn = LIBXSMM_CLMP((unsigned int)n, 1, nn); + tm = LIBXSMM_CLMP((unsigned int)m, 1, mm); +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + if (0 != (1 & libxsmm_xcopy_jit)) { /* JIT'ted transpose permitted? */ + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary((libxsmm_blasint)tm, (libxsmm_blasint)tn, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + } + } +#endif + } + } + libxsmm_otrans_task_internal(out, in, typesize, + (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo, + tm, tn, kernel, tid, ntasks); + } + else if (ldi == ldo) { + libxsmm_itrans(out, typesize, m, n, ldi, ldo); + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n"); + } + } + } + else { + if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + if (0 > tid || tid >= ntasks) { + fprintf(stderr, "LIBXSMM ERROR: the transpose thread-id or number of tasks is incorrect!\n"); + } + else if (NULL == out || NULL == in) { + fprintf(stderr, "LIBXSMM ERROR: the transpose input and/or output is NULL!\n"); + } + else if (out == in) { + fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n"); + } + else if (0 == typesize || 256 <= typesize) { + fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-transpose specified!\n"); + } + else if (ldi < m || ldo < n) { + fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the transpose is/are too small!\n"); + } + else if (0 > m || 0 > n) { + fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the transpose is/are negative!\n"); + } + } + } +} + + +LIBXSMM_API void libxsmm_otrans(void* out, const void* in, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo) +{ + libxsmm_otrans_task(out, in, typesize, m, n, ldi, ldo, 0/*tid*/, 1/*ntasks*/); +} + + +LIBXSMM_API_INTERN void libxsmm_itrans_scratch(void* /*inout*/, void* /*scratch*/, unsigned int /*typesize*/, + libxsmm_blasint /*m*/, libxsmm_blasint /*n*/, libxsmm_blasint /*ldi*/, libxsmm_blasint /*ldo*/); +LIBXSMM_API_INTERN void libxsmm_itrans_scratch(void* inout, void* scratch, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo) +{ + LIBXSMM_ASSERT(NULL != inout && 0 < typesize && m <= ldi && n <= ldo); + LIBXSMM_XCOPY_TILE(LIBXSMM_MCOPY_KERNEL, typesize, scratch, inout, ldi, m, 0, n, 0, m); + LIBXSMM_XCOPY_TILE(LIBXSMM_TCOPY_KERNEL, typesize, inout, scratch, m, ldo, 0, m, 0, n); +} + + +LIBXSMM_API_INTERN void libxsmm_itrans_scratch_jit(void* /*inout*/, void* /*scratch*/, unsigned int /*typesize*/, + libxsmm_blasint /*m*/, libxsmm_blasint /*n*/, libxsmm_blasint /*ldi*/, libxsmm_blasint /*ldo*/, libxsmm_xcopykernel /*kernel*/); +LIBXSMM_API_INTERN void libxsmm_itrans_scratch_jit(void* inout, void* scratch, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, libxsmm_xcopykernel kernel) +{ + LIBXSMM_ASSERT(NULL != inout && 0 < typesize && m <= ldi && n <= ldo); + LIBXSMM_XCOPY_TILE(LIBXSMM_MCOPY_KERNEL, typesize, scratch, inout, ldi, m, 0, n, 0, m); + LIBXSMM_TCOPY_CALL(kernel, typesize, scratch, m, inout, ldo); +} + + +LIBXSMM_API void libxsmm_itrans_internal(char* inout, void* scratch, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride[], + libxsmm_xcopykernel kernel, libxsmm_blasint begin, libxsmm_blasint end) +{ +#if !defined(LIBXSMM_XCOPY_JIT) || 0 == (LIBXSMM_XCOPY_JIT & 1) + LIBXSMM_UNUSED(kernel); +#endif + if (NULL != stride) { + if (0 != index_stride) { /* stride array contains indexes */ + libxsmm_blasint i; + if (NULL == scratch) { /* in-place transpose */ + LIBXSMM_ASSERT(m == n && ldi == ldo); + for (i = begin * index_stride; i < (end * index_stride); i += index_stride) { + char *const mat = &inout[(LIBXSMM_ACCESS(const libxsmm_blasint, stride, i) - index_base) * typesize]; + LIBXSMM_ITRANS(typesize, mat, ldi, m); + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + else if (NULL != kernel.ptr) { /* out-of-place transpose using JIT'ted kernel */ + for (i = begin * index_stride; i < (end * index_stride); i += index_stride) { + char *const mat = &inout[(LIBXSMM_ACCESS(const libxsmm_blasint, stride, i) - index_base) * typesize]; + libxsmm_itrans_scratch_jit(mat, scratch, typesize, m, n, ldi, ldo, kernel); + } + } +#endif + else { /* out-of-place transpose */ + for (i = begin * index_stride; i < (end * index_stride); i += index_stride) { + char *const mat = &inout[(LIBXSMM_ACCESS(const libxsmm_blasint, stride, i) - index_base) * typesize]; + libxsmm_itrans_scratch(mat, scratch, typesize, m, n, ldi, ldo); + } + } + } + else { /* array of pointers to matrices (singular stride is measured in Bytes) */ + const libxsmm_blasint d = *stride - index_base * sizeof(void*); + const char *const endi = inout + (size_t)d * end; + char* i = inout + begin * (size_t)d; + if (NULL == scratch) { /* in-place transpose */ + LIBXSMM_ASSERT(m == n && ldi == ldo); + for (; i < endi; i += d) { + void *const mat = *((void**)i); +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != mat) +#endif + LIBXSMM_ITRANS(typesize, mat, ldi, m); + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + else if (NULL != kernel.ptr) { /* out-of-place transpose using JIT'ted kernel */ + for (; i < endi; i += d) { + void *const mat = *((void**)i); +# if defined(LIBXSMM_BATCH_CHECK) + if (NULL != mat) +# endif + libxsmm_itrans_scratch_jit(mat, scratch, typesize, m, n, ldi, ldo, kernel); + } + } +#endif + else { /* out-of-place transpose */ + for (; i < endi; i += d) { + void *const mat = *((void**)i); +#if defined(LIBXSMM_BATCH_CHECK) + if (NULL != mat) +#endif + libxsmm_itrans_scratch(mat, scratch, typesize, m, n, ldi, ldo); + } + } + } + } + else { /* consecutive matrices */ + libxsmm_blasint i; + if (NULL == scratch) { /* in-place transpose */ + LIBXSMM_ASSERT(m == n && ldi == ldo); + for (i = begin; i < end; ++i) { + LIBXSMM_ITRANS(typesize, inout + (size_t)i * typesize, ldi, m); + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + else if (NULL != kernel.ptr) { /* out-of-place transpose using JIT'ted kernel */ + for (i = begin; i < end; ++i) { + libxsmm_itrans_scratch_jit(inout + (size_t)i * typesize, scratch, typesize, m, n, ldi, ldo, kernel); + } + } +#endif + else { /* out-of-place transpose */ + for (i = begin; i < end; ++i) { + libxsmm_itrans_scratch(inout + (size_t)i * typesize, scratch, typesize, m, n, ldi, ldo); + } + } + } +} + + +LIBXSMM_API void libxsmm_itrans(void* inout, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo) +{ + static int error_once = 0; + if (NULL != inout && 0 < typesize && m <= ldi && n <= ldo) { + if (m == n && ldi == ldo && typesize <= 127) { /* in-place transpose */ + LIBXSMM_ITRANS(typesize, inout, ldi, m); + } + else { /* out-of-place transpose */ + const libxsmm_blasint scratchsize = m * n * typesize; + if (scratchsize <= LIBXSMM_ITRANS_BUFFER_MAXSIZE) { + char buffer[LIBXSMM_ITRANS_BUFFER_MAXSIZE]; + libxsmm_itrans_scratch(inout, buffer, typesize, m, n, ldi, ldo); + } + else { + void* buffer = NULL; + LIBXSMM_INIT + if (EXIT_SUCCESS == libxsmm_xmalloc(&buffer, scratchsize, 0/*auto-align*/, + LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE, + 0/*extra*/, 0/*extra_size*/)) + { + LIBXSMM_ASSERT(NULL != buffer); + libxsmm_itrans_scratch(inout, buffer, typesize, m, n, ldi, ldo); + libxsmm_xfree(buffer, 0/*no check*/); + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate buffer for in-place transpose!\n"); + } + } + } + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument(s) for in-place transpose!\n"); + } +} + + +LIBXSMM_API void libxsmm_itrans_batch(void* inout, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, + libxsmm_blasint index_base, libxsmm_blasint index_stride, + const libxsmm_blasint stride[], libxsmm_blasint batchsize, + /*unsigned*/int tid, /*unsigned*/int ntasks) +{ + static int error_once = 0; + if (NULL != inout && 0 < typesize && m <= ldi && n <= ldo) { + const libxsmm_blasint scratchsize = m * n * typesize; + const libxsmm_blasint size = LIBXSMM_ABS(batchsize); + const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks); + const libxsmm_blasint begin = tid * tasksize, span = begin + tasksize; + const libxsmm_blasint end = LIBXSMM_MIN(span, size); + char buffer[LIBXSMM_ITRANS_BUFFER_MAXSIZE]; + char *const mat0 = (char*)inout; + void* scratch = NULL; + libxsmm_xcopykernel kernel = { NULL }; + if (m != n || ldi != ldo || 127 < typesize) { + if (scratchsize <= LIBXSMM_ITRANS_BUFFER_MAXSIZE) { + scratch = buffer; + } + else { + LIBXSMM_INIT + if (EXIT_SUCCESS != libxsmm_xmalloc(&scratch, scratchsize, 0/*auto-align*/, + LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE, + 0/*extra*/, 0/*extra_size*/) + && 0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: failed to allocate buffer for in-place transpose!\n"); + } + } +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1)) + if (0 != (1 & libxsmm_xcopy_jit) /* JIT'ted transpose permitted? */ + /* avoid outgrown transpose kernel upfront */ + && (m <= LIBXSMM_CONFIG_MAX_DIM || n <= LIBXSMM_CONFIG_MAX_DIM)) + { + switch (typesize) { + case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo, + LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, + LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT); + break; + } + } +#endif + } + libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base, + index_stride, stride, kernel, begin, end); + if (NULL != scratch && LIBXSMM_ITRANS_BUFFER_MAXSIZE < scratchsize) { + libxsmm_xfree(scratch, 0/*no check*/); + } + } + else if (0 != libxsmm_verbosity /* library code is expected to be mute */ + && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) + { + fprintf(stderr, "LIBXSMM ERROR: invalid argument(s) for in-place batch-transpose!\n"); + } +} + + +#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__)) + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matcopy)(void* /*out*/, const void* /*in*/, const int* /*typesize*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_matcopy)(void* out, const void* in, const int* typesize, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo) +{ + libxsmm_blasint ldx; + LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m); + ldx = *(NULL != ldi ? ldi : m); + libxsmm_matcopy(out, in, (unsigned int)*typesize, *m, *(NULL != n ? n : m), ldx, NULL != ldo ? *ldo : ldx); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_otrans)(void* /*out*/, const void* /*in*/, const int* /*typesize*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_otrans)(void* out, const void* in, const int* typesize, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo) +{ + libxsmm_blasint ldx; + LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m); + ldx = *(NULL != ldi ? ldi : m); + libxsmm_otrans(out, in, (unsigned int)*typesize, *m, *(NULL != n ? n : m), ldx, NULL != ldo ? *ldo : ldx); +} + + +/* implementation provided for Fortran 77 compatibility */ +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_itrans)(void* /*inout*/, const int* /*typesize*/, + const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/); +LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_itrans)(void* inout, const int* typesize, + const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo) +{ + const libxsmm_blasint nvalue = *(NULL != n ? n : m); + LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m); + libxsmm_itrans(inout, (unsigned int)*typesize, *m, nvalue, *(NULL != ldi ? ldi : m), NULL != ldo ? *ldo : nvalue); +} + +#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/ + diff --git a/third_party/libxsmm/src/libxsmm_xcopy.h b/third_party/libxsmm/src/libxsmm_xcopy.h new file mode 100644 index 0000000000000000000000000000000000000000..e9ccce6a4ad6a4b9dec78a80905dfe68adf33435 --- /dev/null +++ b/third_party/libxsmm/src/libxsmm_xcopy.h @@ -0,0 +1,286 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ +#ifndef LIBXSMM_XCOPY_H +#define LIBXSMM_XCOPY_H + +#include +#include "libxsmm_main.h" + +#if !defined(LIBXSMM_XCOPY_CHECK) && !defined(NDEBUG) +# define LIBXSMM_XCOPY_CHECK +#endif +#if !defined(LIBXSMM_ITRANS_BUFFER_MAXSIZE) +# if defined(NDEBUG) +# define LIBXSMM_ITRANS_BUFFER_MAXSIZE (12 << 10/*12kB*/) +# else +# define LIBXSMM_ITRANS_BUFFER_MAXSIZE 1 +# endif +#endif +#if !defined(LIBXSMM_XCOPY_TASKSCALE) +# define LIBXSMM_XCOPY_TASKSCALE 2 +#endif +#if !defined(LIBXSMM_XCOPY_TILE_MIN) +# define LIBXSMM_XCOPY_TILE_MIN 2 +#endif +/* 0: none, 1: transpose, 2: matcopy, 3: transpose+matcopy */ +#if defined(LIBXSMM_PLATFORM_X86) +# if !defined(LIBXSMM_XCOPY_JIT) +# if (defined(_WIN32) || defined(__CYGWIN__)) +# define LIBXSMM_XCOPY_JIT 0 +# elif defined(NDEBUG) +# define LIBXSMM_XCOPY_JIT 0 +# else +# define LIBXSMM_XCOPY_JIT 3 +# endif +# endif +#else +# define LIBXSMM_XCOPY_JIT 0 +#endif + +/* kernel uses consecutive stores */ +#define LIBXSMM_MZERO_KERNEL(TYPE, TYPESIZE, OUT, IN, LDI, LDO, INDEX_I, INDEX_J, SRC, DST) \ + static /*const*/ TYPE libxsmm_mzero_kernel_src_value_ /* zero */; \ + const TYPE *const SRC = &libxsmm_mzero_kernel_src_value_; \ + TYPE *const DST = (TYPE*)(((char*)(OUT)) + (TYPESIZE) * ((size_t)(INDEX_I) * (LDO) + (INDEX_J))) +/* kernel uses consecutive stores and consecutive loads (copy) */ +#define LIBXSMM_MCOPY_KERNEL(TYPE, TYPESIZE, OUT, IN, LDI, LDO, INDEX_I, INDEX_J, SRC, DST) \ + const TYPE *const SRC = (const TYPE*)(((const char*) (IN)) + (TYPESIZE) * ((size_t)(INDEX_I) * (LDI) + (INDEX_J))); \ + TYPE *const DST = ( TYPE*)((( char*)(OUT)) + (TYPESIZE) * ((size_t)(INDEX_I) * (LDO) + (INDEX_J))) + +#define LIBXSMM_MZERO_CALL(KERNEL, TYPESIZE, SRC, LDI, DST, LDO) { \ + libxsmm_meltw_unary_param libxsmm_mzero_call_args_; \ + libxsmm_mzero_call_args_.in.primary = (void*)(SRC); \ + libxsmm_mzero_call_args_.out.primary = (DST); \ + (KERNEL).function(&libxsmm_mzero_call_args_); \ + LIBXSMM_UNUSED(LDO); \ +} +#define LIBXSMM_MCOPY_CALL(KERNEL, TYPESIZE, SRC, LDI, DST, LDO) { \ + libxsmm_meltw_unary_param libxsmm_mcopy_call_args_; \ + libxsmm_mcopy_call_args_.in.primary = (void*)(SRC); \ + libxsmm_mcopy_call_args_.out.primary = (DST); \ + (KERNEL).function(&libxsmm_mcopy_call_args_); \ + LIBXSMM_UNUSED(LDO); \ +} + +/* kernel uses consecutive stores and strided loads (transpose) */ +#define LIBXSMM_TCOPY_KERNEL(TYPE, TYPESIZE, OUT, IN, LDI, LDO, INDEX_I, INDEX_J, SRC, DST) \ + const TYPE *const SRC = (const TYPE*)(((const char*) (IN)) + (TYPESIZE) * ((size_t)(INDEX_J) * (LDI) + (INDEX_I))); \ + TYPE *const DST = ( TYPE*)((( char*)(OUT)) + (TYPESIZE) * ((size_t)(INDEX_I) * (LDO) + (INDEX_J))) + +/* call JIT-kernel (transpose) */ +#define LIBXSMM_TCOPY_CALL(KERNEL, TYPESIZE, SRC, LDI, DST, LDO) { \ + libxsmm_meltw_unary_param libxsmm_tcopy_call_args_; \ + libxsmm_tcopy_call_args_.in.primary = (void*)(SRC); \ + libxsmm_tcopy_call_args_.out.primary = (DST); \ + (KERNEL).function(&libxsmm_tcopy_call_args_); \ + LIBXSMM_UNUSED(LDO); \ +} + +#define LIBXSMM_XCOPY_LOOP(TYPE, TYPESIZE, XKERNEL, OUT, IN, LDI, LDO, M0, M1, N0, N1) { \ + libxsmm_blasint libxsmm_xcopy_loop_i_, libxsmm_xcopy_loop_j_; \ + for (libxsmm_xcopy_loop_i_ = M0; libxsmm_xcopy_loop_i_ < (libxsmm_blasint)(M1); ++libxsmm_xcopy_loop_i_) { \ + LIBXSMM_PRAGMA_NONTEMPORAL(OUT) \ + for (libxsmm_xcopy_loop_j_ = N0; libxsmm_xcopy_loop_j_ < (libxsmm_blasint)(N1); ++libxsmm_xcopy_loop_j_) { \ + XKERNEL(TYPE, TYPESIZE, OUT, IN, LDI, LDO, libxsmm_xcopy_loop_i_, libxsmm_xcopy_loop_j_, \ + libxsmm_xcopy_loop_src_, libxsmm_xcopy_loop_dst_); *libxsmm_xcopy_loop_dst_ = *libxsmm_xcopy_loop_src_; \ + } \ + } \ +} + +#define LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) { \ + switch(TYPESIZE) { \ + case 2: { \ + LIBXSMM_XCOPY_LOOP(short, 2, XKERNEL, OUT, IN, LDI, LDO, M0, M1, N0, N1); \ + } break; \ + case 4: { \ + LIBXSMM_XCOPY_LOOP(float, 4, XKERNEL, OUT, IN, LDI, LDO, M0, M1, N0, N1); \ + } break; \ + case 8: { \ + LIBXSMM_XCOPY_LOOP(double, 8, XKERNEL, OUT, IN, LDI, LDO, M0, M1, N0, N1); \ + } break; \ + case 16: { \ + typedef struct /*libxsmm_xcopy_tile_elem_t*/ { double value[2]; } libxsmm_xcopy_tile_elem_t; \ + LIBXSMM_XCOPY_LOOP(libxsmm_xcopy_tile_elem_t, 16, XKERNEL, OUT, IN, LDI, LDO, M0, M1, N0, N1); \ + } break; \ + default: { /* generic type-size */ \ + libxsmm_blasint libxsmm_xcopy_tile_i_, libxsmm_xcopy_tile_j_; \ + for (libxsmm_xcopy_tile_i_ = M0; libxsmm_xcopy_tile_i_ < (libxsmm_blasint)(M1); ++libxsmm_xcopy_tile_i_) { \ + for (libxsmm_xcopy_tile_j_ = N0; libxsmm_xcopy_tile_j_ < (libxsmm_blasint)(N1); ++libxsmm_xcopy_tile_j_) { \ + XKERNEL(char, TYPESIZE, OUT, IN, LDI, LDO, libxsmm_xcopy_tile_i_, libxsmm_xcopy_tile_j_, \ + libxsmm_xcopy_tile_src_, libxsmm_xcopy_tile_dst_); \ + LIBXSMM_MEMCPY127_LOOP(libxsmm_xcopy_tile_dst_, libxsmm_xcopy_tile_src_, TYPESIZE, LIBXSMM_PRAGMA_NONTEMPORAL); \ + } \ + } \ + } \ + } \ +} + +#define LIBXSMM_ITRANS_LOOP(TYPE, INOUT, LD, M) { \ + libxsmm_blasint libxsmm_itrans_loop_i_, libxsmm_itrans_loop_j_; \ + LIBXSMM_ASSERT(NULL != (INOUT) && (M) <= (LD)); \ + for (libxsmm_itrans_loop_i_ = 0; libxsmm_itrans_loop_i_ < (M); ++libxsmm_itrans_loop_i_) { \ + for (libxsmm_itrans_loop_j_ = 0; libxsmm_itrans_loop_j_ < libxsmm_itrans_loop_i_; ++libxsmm_itrans_loop_j_) { \ + TYPE *const libxsmm_itrans_loop_a_ = ((TYPE*)(INOUT)) + (size_t)(LD) * libxsmm_itrans_loop_i_ + libxsmm_itrans_loop_j_; \ + TYPE *const libxsmm_itrans_loop_b_ = ((TYPE*)(INOUT)) + (size_t)(LD) * libxsmm_itrans_loop_j_ + libxsmm_itrans_loop_i_; \ + LIBXSMM_ISWAP(*libxsmm_itrans_loop_a_, *libxsmm_itrans_loop_b_); \ + } \ + } \ +} + +#define LIBXSMM_ITRANS(TYPESIZE, INOUT, LD, M) { \ + switch(TYPESIZE) { \ + case 2: { \ + LIBXSMM_ITRANS_LOOP(short, INOUT, LD, M); \ + } break; \ + case 4: { \ + LIBXSMM_ITRANS_LOOP(int, INOUT, LD, M); \ + } break; \ + case 8: { \ + LIBXSMM_ITRANS_LOOP(int64_t, INOUT, LD, M); \ + } break; \ + default: { /* generic type-size */ \ + const signed char libxsmm_itrans_c_ = (signed char)(TYPESIZE); \ + libxsmm_blasint libxsmm_itrans_i_, libxsmm_itrans_j_; \ + LIBXSMM_ASSERT(NULL != (INOUT) && (M) <= (LD)); \ + LIBXSMM_ASSERT(0 < (TYPESIZE) && (TYPESIZE) <= 127); \ + for (libxsmm_itrans_i_ = 0; libxsmm_itrans_i_ < (M); ++libxsmm_itrans_i_) { \ + for (libxsmm_itrans_j_ = 0; libxsmm_itrans_j_ < libxsmm_itrans_i_; ++libxsmm_itrans_j_) { \ + char *const libxsmm_itrans_a_ = &((char*)(INOUT))[((LD)*libxsmm_itrans_i_+libxsmm_itrans_j_)*(TYPESIZE)]; \ + char *const libxsmm_itrans_b_ = &((char*)(INOUT))[((LD)*libxsmm_itrans_j_+libxsmm_itrans_i_)*(TYPESIZE)]; \ + signed char libxsmm_itrans_k_ = 0; \ + for (; libxsmm_itrans_k_ < libxsmm_itrans_c_; ++libxsmm_itrans_k_) { \ + LIBXSMM_ISWAP( \ + libxsmm_itrans_a_[libxsmm_itrans_k_], \ + libxsmm_itrans_b_[libxsmm_itrans_k_]); \ + } \ + } \ + } \ + } \ + } \ +} + +#define LIBXSMM_MZERO_KERNEL_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, N0, N1, M0, M1) +#define LIBXSMM_MCOPY_KERNEL_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, N0, N1, M0, M1) +#define LIBXSMM_TCOPY_KERNEL_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) + +#define LIBXSMM_XCOPY_NONJIT(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) \ + LIBXSMM_CONCATENATE(XKERNEL,_TILE)(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, M0, M1, N0, N1) + +#if 1 +# define LIBXSMM_XCOPY_PRECOND(COND) +#else +# define LIBXSMM_XCOPY_PRECOND(COND) COND +#endif + +#define LIBXSMM_XCOPY_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) { \ + libxsmm_blasint libxsmm_xcopy_i_ = M0, libxsmm_xcopy_j_ = N0; \ + LIBXSMM_ASSERT_MSG(0 < (TILE_M) && 0 < (TILE_N), "XCOPY cannot make progress"); \ + if (NULL != (KERNEL).ptr) { /* inner tiles with JIT */ \ + for (; libxsmm_xcopy_i_ < (((libxsmm_blasint)M1) - ((libxsmm_blasint)TILE_M) + 1); libxsmm_xcopy_i_ += TILE_M) { \ + for (libxsmm_xcopy_j_ = N0; libxsmm_xcopy_j_ < (((libxsmm_blasint)N1) - ((libxsmm_blasint)TILE_N) + 1); libxsmm_xcopy_j_ += TILE_N) { \ + XKERNEL(char, TYPESIZE, OUT, IN, LDI, LDO, libxsmm_xcopy_i_, libxsmm_xcopy_j_, libxsmm_xcopy_src_, libxsmm_xcopy_dst_); \ + KERNEL_CALL(KERNEL, TYPESIZE, libxsmm_xcopy_src_, LDI, libxsmm_xcopy_dst_, LDO); \ + } \ + } \ + } \ + else { /* inner tiles without JIT */ \ + for (; libxsmm_xcopy_i_ < (((libxsmm_blasint)M1) - ((libxsmm_blasint)TILE_M) + 1); libxsmm_xcopy_i_ += TILE_M) { \ + for (libxsmm_xcopy_j_ = N0; libxsmm_xcopy_j_ < (((libxsmm_blasint)N1) - ((libxsmm_blasint)TILE_N) + 1); libxsmm_xcopy_j_ += TILE_N) { \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, \ + libxsmm_xcopy_i_, libxsmm_xcopy_i_ + (TILE_M), \ + libxsmm_xcopy_j_, libxsmm_xcopy_j_ + (TILE_N)); \ + } \ + } \ + } \ + { /* remainder/border tiles */ \ + LIBXSMM_XCOPY_PRECOND(if (libxsmm_xcopy_j_ < ((libxsmm_blasint)N1))) { \ + for (libxsmm_xcopy_i_ = M0; libxsmm_xcopy_i_ < (((libxsmm_blasint)M1) - ((libxsmm_blasint)TILE_M) + 1); libxsmm_xcopy_i_ += TILE_M) { \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, \ + libxsmm_xcopy_i_, libxsmm_xcopy_i_ + (TILE_M), \ + libxsmm_xcopy_j_, N1); \ + } \ + } \ + LIBXSMM_XCOPY_PRECOND(if (libxsmm_xcopy_i_ < ((libxsmm_blasint)M1))) { \ + for (libxsmm_xcopy_j_ = N0; libxsmm_xcopy_j_ < (((libxsmm_blasint)N1) - ((libxsmm_blasint)TILE_N)); libxsmm_xcopy_j_ += TILE_N) { \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, \ + libxsmm_xcopy_i_, M1, \ + libxsmm_xcopy_j_, libxsmm_xcopy_j_ + (TILE_N)); \ + } \ + } \ + LIBXSMM_XCOPY_PRECOND(if (libxsmm_xcopy_i_ < ((libxsmm_blasint)M1) && libxsmm_xcopy_j_ < ((libxsmm_blasint)N1))) { \ + LIBXSMM_XCOPY_TILE(XKERNEL, TYPESIZE, OUT, IN, LDI, LDO, \ + libxsmm_xcopy_i_, M1, \ + libxsmm_xcopy_j_, N1); \ + } \ + } \ +} + +#define LIBXSMM_MZERO_KERNEL_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) \ + LIBXSMM_XCOPY_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_N, TILE_M, N0, N1, M0, M1) +#define LIBXSMM_MCOPY_KERNEL_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) \ + LIBXSMM_XCOPY_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_N, TILE_M, N0, N1, M0, M1) +#define LIBXSMM_TCOPY_KERNEL_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) \ + LIBXSMM_XCOPY_TILES(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) + +#define LIBXSMM_XCOPY(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) \ + LIBXSMM_CONCATENATE(XKERNEL,_TILES)(XKERNEL, KERNEL_CALL, KERNEL, OUT, IN, TYPESIZE, LDI, LDO, TILE_M, TILE_N, M0, M1, N0, N1) + +/** Initializes the transpose functionality; NOT thread-safe. */ +LIBXSMM_API_INTERN void libxsmm_xcopy_init(int archid); +/** Finalizes the transpose functionality; NOT thread-safe. */ +LIBXSMM_API_INTERN void libxsmm_xcopy_finalize(void); + +LIBXSMM_API void libxsmm_matcopy_task_internal(void* out, const void* in, unsigned int typesize, + unsigned int m, unsigned int n, unsigned int ldi, unsigned int ldo, + unsigned int km, unsigned int kn, libxsmm_xcopykernel kernel, + int tid, int ntasks); +LIBXSMM_API void libxsmm_otrans_task_internal(void* out, const void* in, unsigned int typesize, + unsigned int m, unsigned int n, unsigned int ldi, unsigned int ldo, + unsigned int km, unsigned int kn, libxsmm_xcopykernel kernel, + int tid, int ntasks); + +LIBXSMM_API_INTERN void libxsmm_matcopy_internal(void* out, const void* in, + unsigned int typesize, unsigned int ldi, unsigned int ldo, + unsigned int m0, unsigned int m1, unsigned int n0, unsigned int n1, + unsigned int tm, unsigned int tn, libxsmm_xcopykernel kernel); +LIBXSMM_API_INTERN void libxsmm_matzero_internal(void* out, + unsigned int typesize, unsigned int ldo, + unsigned int m0, unsigned int m1, unsigned int n0, unsigned int n1, + unsigned int tm, unsigned int tn, libxsmm_xcopykernel kernel); +LIBXSMM_API_INTERN void libxsmm_otrans_internal(void* out, const void* in, + unsigned int typesize, unsigned int ldi, unsigned int ldo, + unsigned int m0, unsigned int m1, unsigned int n0, unsigned int n1, + unsigned int tm, unsigned int tn, libxsmm_xcopykernel kernel); +LIBXSMM_API void libxsmm_itrans_internal(char* inout, void* scratch, unsigned int typesize, + libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo, + libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride[], + libxsmm_xcopykernel kernel, libxsmm_blasint begin, libxsmm_blasint end); + +#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT)) +/** Determines whether JIT-kernels are used or not; values see LIBXSMM_XCOPY_JIT. */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_xcopy_jit); +#endif +/** Determines if OpenMP tasks are used, and scales beyond the number of threads. */ +LIBXSMM_APIVAR_PUBLIC(int libxsmm_xcopy_taskscale); +/** M-extent of type-size in Byte. */ +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_mcopy_mbytes); +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_mzero_mbytes); +LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_tcopy_mbytes); +/** M-factor shaping the N-extent. */ +LIBXSMM_APIVAR_PUBLIC(float libxsmm_mcopy_nscale); +LIBXSMM_APIVAR_PUBLIC(float libxsmm_mzero_nscale); +LIBXSMM_APIVAR_PUBLIC(float libxsmm_tcopy_nscale); + +#endif /*LIBXSMM_XCOPY_H*/ + diff --git a/third_party/libxsmm/src/template/libxsmm_config.h b/third_party/libxsmm/src/template/libxsmm_config.h new file mode 100644 index 0000000000000000000000000000000000000000..bfb98616e75ae422fa31ff29befc860b1353e0b7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_config.h @@ -0,0 +1,44 @@ +#ifndef LIBXSMM_CONFIG_H +#define LIBXSMM_CONFIG_H + +#if !defined(LIBXSMM_DEFAULT_CONFIG) && defined(LIBXSMM_SOURCE_H) && !defined(LIBXSMM_CONFIGURED) +# define LIBXSMM_DEFAULT_CONFIG +#endif +#if !defined(LIBXSMM_DEFAULT_CONFIG) && defined(_WIN32) +# define LIBXSMM_DEFAULT_CONFIG +#endif + +#if !defined(LIBXSMM_DEFAULT_CONFIG) && (!defined(LIBXSMM_SOURCE_H) || defined(LIBXSMM_CONFIGURED)) +# include "libxsmm_version.h" +$LIBXSMM_OFFLOAD_BUILD +$MNK_PREPROCESSOR_LIST +#else +# define LIBXSMM_CONFIG_VERSION "" +# define LIBXSMM_CONFIG_BRANCH "" +# define LIBXSMM_CONFIG_VERSION_MAJOR INT_MAX +# define LIBXSMM_CONFIG_VERSION_MINOR INT_MAX +# define LIBXSMM_CONFIG_VERSION_UPDATE INT_MAX +# define LIBXSMM_CONFIG_VERSION_PATCH INT_MAX +# define LIBXSMM_CONFIG_BUILD_DATE INT_MAX +#endif + +#define LIBXSMM_CONFIG_CACHELINE $CACHELINE +#define LIBXSMM_CONFIG_ALIGNMENT $CACHELINE +#define LIBXSMM_CONFIG_MALLOC $MALLOC +#define LIBXSMM_CONFIG_ILP64 $ILP64 +#define LIBXSMM_CONFIG_SYNC $SYNC +#define LIBXSMM_CONFIG_JIT $JIT + +#define LIBXSMM_CONFIG_PREFETCH $PREFETCH +#define LIBXSMM_CONFIG_MAX_MNK $MAX_MNK +#define LIBXSMM_CONFIG_MAX_DIM $MAX_DIM +#define LIBXSMM_CONFIG_AVG_DIM $AVG_DIM +#define LIBXSMM_CONFIG_MAX_M $MAX_M +#define LIBXSMM_CONFIG_MAX_N $MAX_N +#define LIBXSMM_CONFIG_MAX_K $MAX_K +#define LIBXSMM_CONFIG_FLAGS $FLAGS +#define LIBXSMM_CONFIG_ALPHA $ALPHA +#define LIBXSMM_CONFIG_BETA $BETA +#define LIBXSMM_CONFIG_WRAP $WRAP + +#endif diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_bf16_macros_define.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_bf16_macros_define.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..fd0c448bf635e34d415f5c3fed4c3430e5016f10 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_bf16_macros_define.tpl.c @@ -0,0 +1,95 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +#if 0 +#define USE_CLDEMOTE +#define WR_PREFETCH_OUTPUT +#endif + +#if defined(LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI) +# define LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH( A ) (__m256i)_mm512_cvtneps_pbh( A ) +# define LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( A, B ) (__m512i)_mm512_cvtne2ps_pbh( A, B ) +#else +# define LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH( A ) LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16( A ) +# define LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( A, B ) LIBXSMM_INTRINSICS_MM512_CVT2_FP32_BF16( A, B ) +#endif + +#ifdef WR_PREFETCH_OUTPUT +#define prefetchwt_chunk(ptr, nbytes) do { \ + int __i; \ + for (__i = 0; __i < nbytes; __i += 64) { \ + _mm_prefetch((char*)ptr+__i, _MM_HINT_ET0); \ + } \ +} while(0) +#endif + +#ifdef USE_CLDEMOTE +#define LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(in, out, length) do { \ + unsigned int full_chunks = length / 32; \ + unsigned int remainder = length % 32; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 32) { \ + _mm512_storeu_si512((libxsmm_bfloat16*)out+__i, LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i+16), LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i))); \ + _mm_cldemote((libxsmm_bfloat16*)out+__i); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 32; \ + _mm512_storeu_si512((libxsmm_bfloat16*)out+__i, LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i+16), LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i))); \ + _mm_cldemote((libxsmm_bfloat16*)out+__i); \ + } \ + libxsmm_rne_convert_fp32_bf16((const float*)in+32*full_chunks, (element_output_type*)out+32*full_chunks, remainder); \ + _mm_cldemote((libxsmm_bfloat16*)out+32*full_chunks); \ + } \ +} while(0) +#else +#define LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(in, out, length) do { \ + unsigned int full_chunks = length / 32; \ + unsigned int remainder = length % 32; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 32) { \ + _mm512_storeu_si512((libxsmm_bfloat16*)out+__i, LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i+16), LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i))); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 32; \ + _mm512_storeu_si512((libxsmm_bfloat16*)out+__i, LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i+16), LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i))); \ + } \ + libxsmm_rne_convert_fp32_bf16((const float*)in+32*full_chunks, (element_output_type*)out+32*full_chunks, remainder); \ + } \ +} while(0) +#endif + +#define LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32(in, out, length) do { \ + unsigned int full_chunks = length / 16; \ + unsigned int remainder = length % 16; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 16) { \ + _mm512_storeu_ps( (float*)out+__i, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS( _mm256_loadu_si256((__m256i*)((const libxsmm_bfloat16*)in+__i)))); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 16; \ + _mm512_storeu_ps( (float*)out+__i, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS( _mm256_loadu_si256((__m256i*)((const libxsmm_bfloat16*)in+__i)))); \ + } \ + libxsmm_convert_bf16_f32((const libxsmm_bfloat16*)in+16*full_chunks, (float*)out+16*full_chunks, remainder); \ + } \ +} while(0) + +#define _mm512_loadcvt_bf16_fp32(A) LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)(A))) +#define _mm512_storecvt_fp32_bf16(A,B) _mm256_storeu_si256((__m256i*)(A),(__m256i)LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH((B))) +#define _mm512_streamstorecvt_fp32_bf16(A,B) _mm256_stream_si256((__m256i*)(A), (__m256i)LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH((B))) + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_bf16_macros_undefine.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_bf16_macros_undefine.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..f5e2a127c734ba00bcc4bcc53e870101b384113b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_bf16_macros_undefine.tpl.c @@ -0,0 +1,28 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +#undef LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16 +#undef LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32 +#undef LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH +#undef LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH +#undef _mm512_loadcvt_bf16_fp32 +#undef _mm512_storecvt_fp32_bf16 +#undef _mm512_streamstorecvt_fp32_bf16 + +#ifdef USE_CLDEMOTE +#undef USE_CLDEMOTE +#endif + +#ifdef WR_PREFETCH_OUTPUT +#undef prefetchwt_chunk +#undef WR_PREFETCH_OUTPUT +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..5afa0f7f9c361cc9a4d70a29e4d78c83818a76fe --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c @@ -0,0 +1,177 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Rajkishore Barik, Ankush Mandal, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +int imgifm1, img, ofm1, ifm1, oj, ij, oi, ii, kj, ki, ifm2, ofm2, ifm1ofm1; +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* number of tasks that could be run in parallel */ +const int work = handle->desc.N * handle->blocksifm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm; +/* compute chunk size */ +const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + +/* offset pointer in case of physical padding */ +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; + +/* Weight and transpose_weight tensor declaration */ +LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = 0; + +/* padding via stack allocated buffers */ +const int padded_w = handle->desc.W + (2 * handle->desc.pad_w); +const int padded_h = handle->desc.H + (2 * handle->desc.pad_h); +const int size_tls1 = padded_h * padded_w * handle->ifmblock; +element_input_type *const del_input_scratch_padding = (element_input_type*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) + ltid * size_tls1; +for ( ii = 0; ii < size_tls1; ++ii ) { del_input_scratch_padding[ii] = (element_input_type)0; } + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) { + weight_base = (element_filter_type*)handle->reg_filter_tr->data; +} else { + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / handle->blocksifm; + ifm1 = ifm1ofm1 % handle->blocksifm; + for (kj=0; kj < handle->desc.R; kj++) { + for (ki=0; ki < handle->desc.S; ki++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } + } + } + } + } + weight_base = (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); +} + +{/* open new scope for additional variable declarations (C89) */ +LIBXSMM_VLA_DECL(5, element_input_type, del_input, (element_output_type*)handle->grad_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); +LIBXSMM_VLA_DECL(3, element_input_type, del_input_padded, del_input_scratch_padding, padded_w, handle->ifmblock); +LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + +for (imgifm1 = thr_begin; imgifm1 < thr_end; ++imgifm1) { + img = imgifm1 / handle->blocksifm; + ifm1 = imgifm1 % handle->blocksifm; + + /* check if we need padding, for now we do physical padding on the fly, however we can play with N parameter of the GEMM */ + /* @TODO: add variant which deals with multiple GEMMS by varying N to deal with padding */ + if ( (handle->desc.pad_h == handle->desc.pad_h_in) && (handle->desc.pad_w == handle->desc.pad_w_in) ) { + + /* reset result buffer to zero when intent is to overwrite when first block + of input channels should be convoluted */ + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) ) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock)); + LIBXSMM_PRAGMA_SIMD + for (ij = 0; ij < handle->ifhp*handle->ifwp*handle->ifmblock; ij++) { + temp_ptr[ij] = (element_input_type)0; + } + } + + /* run convolution */ + for (ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1) { + for ( oj = 0; oj < handle->ofh; ++oj) { + ij = oj * handle->desc.u; + oi = 0; ii = 0; + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + gemm_kernel( &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij + kj, ii + ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) ); + } + } + } + } + + /* zero rim in case of physical padding.... this code is extremely stupid and crappy as it requires a complicated if... */ + if (handle->desc.pad_h_in > 0 || handle->desc.pad_w_in > 0) { + for ( ij = 0; ij < handle->ifhp; ij++ ) { + for ( ii = 0; ii < handle->ifwp; ii++ ) { + if ( (ij < handle->desc.pad_h_in) || (ij >= (handle->desc.H+handle->desc.pad_h_in)) || + (ii < handle->desc.pad_w_in) || (ii >= (handle->desc.W+handle->desc.pad_w_in)) ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } else { + /* reset result buffer to zero when intent is to overwrite when first block + of input channels should be convoluted */ + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) ) { + LIBXSMM_PRAGMA_SIMD + for (ij = 0; ij < size_tls1; ++ij) { + del_input_scratch_padding[ij] = (element_output_type)0; + } + } else { + for (ij = 0; ij < handle->desc.H; ij++) { + for (ii = 0; ii < handle->desc.W; ii++) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + handle->desc.pad_h, ii + handle->desc.pad_w, ifm2, padded_w, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + + /* run convolution */ + for (ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1) { + for ( oj = 0; oj < handle->ofh; ++oj) { + ij = oj * handle->desc.u; + oi = 0; ii = 0; + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + gemm_kernel( &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + kj, ii + ki, 0, padded_w, handle->ifmblock) ); + } + } + } + } + + /* input padding copy back */ + for (ij = 0; ij < handle->desc.H; ij++) { + for (ii = 0; ii < handle->desc.W; ii++) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + handle->desc.pad_h, ii + handle->desc.pad_w, ifm2, padded_w, handle->ifmblock); + } + } + } + } +} /* end of imgifm1 loop */ + +} /* end of new scope for additional variable declarations (C89) */ + +libxsmm_barrier_wait(handle->barrier, ltid); diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..40f9fd0a6c36cb56b0ce96d84199293ac2dfb34a --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c @@ -0,0 +1,172 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +int imgifm1, img, ofm1, ifm1, oj, ij, oi, ii, kj, ki, ifm2, ofm2; +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* auxiliary lp variables */ +int ofmblock_lp = handle->ofmblock/handle->fm_lp_block; +int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +int lpb = handle->fm_lp_block; +unsigned long long n_blocks = handle->blocksofm; + +/* number of tasks that could be run in parallel */ +int task; +const int work = handle->desc.N * handle->blocksifm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm * handle->desc.R * handle->desc.S; +/* compute chunk size */ +const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + +/* offset pointer in case of physical padding */ +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; + +/* Weight and transpose_weight tensor declaration */ +LIBXSMM_VLA_DECL(7, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); +LIBXSMM_VLA_DECL(7, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = 0; + +/* padding via stack allocated buffers */ +const int padded_w = handle->desc.W + (2 * handle->desc.pad_w); +const int padded_h = handle->desc.H + (2 * handle->desc.pad_h); +const int size_tls1 = padded_h * padded_w * handle->ifmblock; +float *const del_input_scratch_padding = (float*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) + ltid * size_tls1; +for ( ii = 0; ii < size_tls1; ++ii ) { del_input_scratch_padding[ii] = (float)0.0; } + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) == 0 ) { + for (task = transpose_thr_begin; task < transpose_thr_end; ++task) { + ifm1 = task/(handle->blocksofm * handle->desc.R * handle->desc.S); + ofm1 = (task%(handle->blocksofm * handle->desc.R * handle->desc.S))/(handle->desc.R * handle->desc.S); + kj = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))/handle->desc.S; + ki = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))%handle->desc.S; + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2/lpb, ifm2, ofm2%lpb, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb) = + LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, ifm2/lpb, ofm2, ifm2%lpb, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); + } + } + } + weight_base = (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); +} else { + weight_base = (element_filter_type*)handle->reg_filter_tr->data; +} + +{/* open new scope for additional variable declarations (C89) */ +LIBXSMM_VLA_DECL(5, element_input_type, del_input, (element_output_type*)handle->grad_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); +LIBXSMM_VLA_DECL(3, float, del_input_padded, del_input_scratch_padding, padded_w, handle->ifmblock); +LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(7, element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); +/* Auxiliary fp32 accumulators */ +float *del_inp_fp32 = (float*)((char*)handle->scratch + handle->bwd_lp_input_full_scratch_offset); +LIBXSMM_VLA_DECL(5, float, del_input_fp32, del_inp_fp32, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + +for (imgifm1 = thr_begin; imgifm1 < thr_end; ++imgifm1) { + img = imgifm1 / handle->blocksifm; + ifm1 = imgifm1 % handle->blocksifm; + + /* check if we need padding, for now we do physical padding on the fly, however we can play with N parameter of the GEMM */ + /* @TODO: add variant which deals with multiple GEMMS by varying N to deal with padding */ + if ( (handle->desc.pad_h == handle->desc.pad_h_in) && (handle->desc.pad_w == handle->desc.pad_w_in) ) { + + /* reset result buffer to zero when intent is to overwrite when first block + of input channels should be convoluted */ + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) ) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input_fp32, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock)); + LIBXSMM_PRAGMA_SIMD + for (ij = 0; ij < handle->ifhp*handle->ifwp*handle->ifmblock; ij++) { + temp_ptr[ij] = (float)0.0; + } + } + + /* run convolution */ + for ( oj = 0; oj < handle->ofh; ++oj) { + ij = oj * handle->desc.u; + oi = 0; ii = 0; + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + bf16fp32_brgemm_kernel( &LIBXSMM_VLA_ACCESS(7, weight, ifm1, 0, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb), + &LIBXSMM_VLA_ACCESS(5, output, img, 0, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij + kj, ii + ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), &n_blocks ); + } + } + } + + /* Downconvert computed result to bf16 */ + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + handle->ifhp * handle->ifwp * handle->ifmblock); + + /* zero rim in case of physical padding.... this code is extremely stupid and crappy as it requires a complicated if... */ + if (handle->desc.pad_h_in > 0 || handle->desc.pad_w_in > 0) { + for ( ij = 0; ij < handle->ifhp; ij++ ) { + for ( ii = 0; ii < handle->ifwp; ii++ ) { + if ( (ij < handle->desc.pad_h_in) || (ij >= (handle->desc.H+handle->desc.pad_h_in)) || + (ii < handle->desc.pad_w_in) || (ii >= (handle->desc.W+handle->desc.pad_w_in)) ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + + } else { + /* reset result buffer to zero when intent is to overwrite when first block + of input channels should be convoluted */ + LIBXSMM_PRAGMA_SIMD + for (ij = 0; ij < size_tls1; ++ij) { + del_input_scratch_padding[ij] = (float)0.0; + } + + + /* run convolution */ + for ( oj = 0; oj < handle->ofh; ++oj) { + ij = oj * handle->desc.u; + oi = 0; ii = 0; + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + bf16fp32_brgemm_kernel( &LIBXSMM_VLA_ACCESS(7, weight, ifm1, 0, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb), + &LIBXSMM_VLA_ACCESS(5, output, img, 0, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + kj, ii + ki, 0, padded_w, handle->ifmblock), &n_blocks ); + } + } + } + + /* input padding copy back */ + for (ij = 0; ij < handle->desc.H; ij++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + handle->desc.pad_h, handle->desc.pad_w, 0, padded_w, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + handle->desc.W * handle->ifmblock); + } + } +} /* end of imgifm1 loop */ + +} /* end of new scope for additional variable declarations (C89) */ + +libxsmm_barrier_wait(handle->barrier, ltid); diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..7738322adb3fb24bcf75e127bf0f76f11b19067e --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c @@ -0,0 +1,352 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +int img, ofm1, ofm2, ifm1, ifm2, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myIfmId, nIfmBlocks, ind, task, ifm1ofm1; +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int my_ifm_start = 0; +int my_ifm_end = handle->blocksifm; + +/* Batch reduce related variables */ +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm * handle->desc.R * handle->desc.S; +/* compute chunk size */ +int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; +/* offset output pointer in case of physical padding */ +const int IFW = (handle->pack_input_bwd == 1) ? handle->ofw : handle->ifwp; +const int IFH = (handle->pack_input_bwd == 1) ? handle->ofh : handle->ifhp; +element_input_type *input_ptr = (handle->pack_input_bwd == 1) ? (element_input_type*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) : (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock; +LIBXSMM_VLA_DECL(5, element_input_type, del_input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +element_output_type *const out = (element_output_type*)handle->grad_output->data; +LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + +/* Weight and transpose_weight tensor declaration */ +LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = ((handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) ? (element_filter_type*)handle->reg_filter_tr->data : (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) == 0 ) { + /* Special case of 64x64 transpose with JITed transpose */ + if (handle->ifmblock == 64 && handle->ofmblock == 64) { + libxsmm_meltwfunction_unary tr_kernel = handle->tr_kernel; + libxsmm_meltw_unary_param trans_param; + for (task = transpose_thr_begin; task < transpose_thr_end; ++task) { + ifm1 = task/(handle->blocksofm * handle->desc.R * handle->desc.S); + ofm1 = (task%(handle->blocksofm * handle->desc.R * handle->desc.S))/(handle->desc.R * handle->desc.S); + kj = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))/handle->desc.S; + ki = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))%handle->desc.S; + trans_param.in.primary = &LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + trans_param.out.primary = &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + tr_kernel( &trans_param ); + trans_param.in.primary = &LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 16, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + trans_param.out.primary = &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 16, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + tr_kernel( &trans_param ); + trans_param.in.primary = &LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 32, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + trans_param.out.primary = &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 32, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + tr_kernel( &trans_param ); + trans_param.in.primary = &LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 48, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + trans_param.out.primary = &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 48, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + tr_kernel( &trans_param ); + } + } else { + /* number of tasks for transpose that could be run in parallel */ + transpose_work = handle->blocksifm * handle->blocksofm; + /* compute chunk size */ + transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / handle->blocksifm; + ifm1 = ifm1ofm1 % handle->blocksifm; + for (kj=0; kj < handle->desc.R; kj++) { + for (ki=0; ki < handle->desc.S; ki++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } + } + } + } + } + } + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N); + myIfmId = ltid % threads_per_image; + nIfmBlocks = LIBXSMM_UPDIV(handle->blocksifm, threads_per_image); + my_ifm_start = LIBXSMM_MIN(myIfmId * nIfmBlocks, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((myIfmId+1) * nIfmBlocks, handle->blocksifm); +} + +if ( handle->use_ifm_parallelization == 1 ) { + int spread_out = 0; + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N); + my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + } +} + +if (handle->loop_order == 0) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) {*/ + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (element_input_type)0; + } + temp_ptr += handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + ij_use = oj; + ii_use = oi; + oj_use = oj - (1-handle->desc.pad_h_out); + oi_use = oi - (1-handle->desc.pad_w_out); + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else if (oi == handle->ofw-handle->bwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (element_input_type)0; + } + temp_ptr += handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } + } + } + } + } + } + } + } + } +} + +if (handle->loop_order == 1) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) { */ + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (element_input_type)0; + } + temp_ptr += handle->ifmblock; + } + } + } + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } + } + } + } + } + } + } + } +} + +if (handle->pack_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, oj/handle->desc.u, oi/handle->desc.v, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock); + } + } + } + } + } + } +} else if (handle->spread_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..efd2d68e78acd0f8f85dbcde6060006451612e58 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c @@ -0,0 +1,407 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +int img, ofm1, ofm2, ifm1, ifm2, oj, ojj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myIfmId, nIfmBlocks, ind, task; +int last_ki, last_kj, next_kj; +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int my_ifm_start = 0; +int my_ifm_end = handle->blocksifm; +int ofmblock_lp = handle->ofmblock/handle->fm_lp_block; +int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +int lpb = handle->fm_lp_block; + +/* Batch reduce related variables */ +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm * handle->desc.R * handle->desc.S; +/* compute chunk size */ +int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; +/* offset output pointer in case of physical padding */ +const int IFW = (handle->pack_input_bwd == 1) ? handle->ofw : handle->ifwp; +const int IFH = (handle->pack_input_bwd == 1) ? handle->ofh : handle->ifhp; +const int ifwp_scratch = (handle->spread_input_bwd == 1) ? handle->desc.v * handle->bwd_ofw_rb : handle->bwd_ofw_rb; + +/* Auxiliary fp32 accumulators */ +float *del_inp_ptr; +float *del_inp_fp32 = (float*)((char*)handle->scratch + handle->bwd_lp_input_full_scratch_offset) + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock; +LIBXSMM_VLA_DECL(5, float, del_input_fp32, del_inp_fp32, handle->blocksifm, IFH, IFW, handle->ifmblock); + +element_input_type *input_ptr = (handle->pack_input_bwd == 1) ? (element_input_type*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) : (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock; +LIBXSMM_VLA_DECL(5, element_input_type, del_input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +element_output_type *const out = (element_output_type*)handle->grad_output->data; +LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + +/* Weight and transpose_weight tensor declaration */ +LIBXSMM_VLA_DECL(7, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); +LIBXSMM_VLA_DECL(7, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = ((handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) ? (element_filter_type*)handle->reg_filter_tr->data : (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); +LIBXSMM_VLA_DECL(7, const element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) == 0 ) { + for (task = transpose_thr_begin; task < transpose_thr_end; ++task) { + ifm1 = task/(handle->blocksofm * handle->desc.R * handle->desc.S); + ofm1 = (task%(handle->blocksofm * handle->desc.R * handle->desc.S))/(handle->desc.R * handle->desc.S); + kj = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))/handle->desc.S; + ki = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))%handle->desc.S; + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2/lpb, ifm2, ofm2%lpb, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb) = + LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, ifm2/lpb, ofm2, ifm2%lpb, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); + } + } + } + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N); + myIfmId = ltid % threads_per_image; + nIfmBlocks = LIBXSMM_UPDIV(handle->blocksifm, threads_per_image); + my_ifm_start = LIBXSMM_MIN(myIfmId * nIfmBlocks, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((myIfmId+1) * nIfmBlocks, handle->blocksifm); +} + +if ( handle->use_ifm_parallelization == 1 ) { + int spread_out = 0; + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N); + my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + } +} + +if (handle->loop_order == 0) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) {*/ + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float *temp_ptr = (float*)&LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (float)0; + } + temp_ptr += handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + ij_use = oj; + ii_use = oi; + oj_use = oj - (1-handle->desc.pad_h_out); + oi_use = oi - (1-handle->desc.pad_w_out); + last_kj = handle->desc.R-1; + last_ki = handle->desc.S-1; + next_kj = kj+1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + if (handle->avoid_acc_load_bwd == 1) { + br_gemm_kernel2_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else { + del_inp_ptr = &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (ofm2 == handle->blocksofm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } + } + } else if (oi == handle->ofw-handle->bwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + if (handle->avoid_acc_load_bwd == 1) { + br_gemm_kernel2_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else { + del_inp_ptr = &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (ofm2 == handle->blocksofm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } + } + } else { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + if (handle->avoid_acc_load_bwd == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else { + del_inp_ptr = &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (ofm2 == handle->blocksofm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float *temp_ptr = (float*)&LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (float)0; + } + temp_ptr += handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + if (handle->avoid_acc_load_bwd == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else { + del_inp_ptr = &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (ofm2 == handle->blocksofm && kj == handle->desc.R && ki == handle->desc.S) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + ifwp_scratch * handle->ifmblock); + } + } + } + } + } + } + } + } + } + } + } + } +} + +if (handle->loop_order == 1) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) { */ + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float *temp_ptr = (float*)&LIBXSMM_VLA_ACCESS( 5, del_input_fp32, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (float)0; + } + temp_ptr += handle->ifmblock; + } + } + } + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + if (handle->avoid_acc_load_bwd == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), &n_blocks); + } else { + del_inp_ptr = &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (ofm2 == handle->blocksofm && kj == handle->desc.R && ki == handle->desc.S) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + ifwp_scratch * handle->ifmblock); + } + } + } + } + } + } + } + } + } + } + } +} + +if (handle->pack_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, oj/handle->desc.u, oi/handle->desc.v, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock); + } + } + } + } + } + } +} else if (handle->spread_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..36b7d4fde9f6b57b619726541879c544e4fd5a9c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c @@ -0,0 +1,530 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +int img, ofm1, ofm2, ifm1, ifm2, oj, ojj, oi, kj, ki, /*oi_use, oj_use, ii_use, ij_use, ofmb,*/ ifmb, ojb, myIfmId, nIfmBlocks, /*ind,*/ task; +/*int last_ki, last_kj, next_kj;*/ +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = (handle->desc.N + handle->desc.threads - 1)/handle->desc.threads; +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN( ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN( (ltid+1) * imgpt, handle->desc.N); +int my_ifm_start = 0; +int my_ifm_end = handle->blocksifm; +int ofmblock_lp = handle->ofmblock/handle->fm_lp_block; +int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +int lpb = handle->fm_lp_block; + +/* Batch reduce related variables */ +#if 0 +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +#endif +unsigned long long n_blocks; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm * handle->desc.R * handle->desc.S; +/* compute chunk size */ +int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; +/* offset output pointer in case of physical padding */ +const int IFW = (handle->pack_input_bwd == 1) ? handle->ofw : handle->ifwp; +const int IFH = (handle->pack_input_bwd == 1) ? handle->ofh : handle->ifhp; + +/* Auxiliary fp32 accumulators */ +float *out_ptr; +/*float *del_inp_fp32 = (float*)handle->scratch6 + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock;*/ +float *del_inp_scratch = (float*)((char*)handle->scratch + handle->bwd_lp_input_full_scratch_offset) + ltid * handle->bwd_gemm_pixels * handle->ifmblock; +/*LIBXSMM_VLA_DECL(5, float, del_input_fp32, del_inp_fp32, handle->blocksifm, IFH, IFW, handle->ifmblock);*/ +int scratch_ifwp = (handle->bwd_gemm_pixels == (handle->bwd_ofw_rb * handle->bwd_ofh_rb)) ? handle->bwd_ofw_rb : handle->ifwp; +LIBXSMM_VLA_DECL(3, float, scratch_fp32, del_inp_scratch, scratch_ifwp, handle->ifmblock); + +element_input_type *input_ptr = (handle->pack_input_bwd == 1) ? (element_input_type*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) : (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock; +LIBXSMM_VLA_DECL(5, element_input_type, del_input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +element_output_type *const out = (element_output_type*)handle->grad_output->data; +LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + +/* Weight and transpose_weight tensor declaration */ +LIBXSMM_VLA_DECL(7, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); +LIBXSMM_VLA_DECL(7, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = ((handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) ? (element_filter_type*)handle->reg_filter_tr->data : (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); +LIBXSMM_VLA_DECL(7, const element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* Execute the tileconfig kernel */ +tile_config_kernel(NULL, NULL, NULL); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) == 0 ) { + if ((handle->ifmblock % 16 == 0) && (handle->ofmblock % 16 == 0)) { + for (task = transpose_thr_begin; task < transpose_thr_end; ++task) { + ifm1 = task/(handle->blocksofm * handle->desc.R * handle->desc.S); + ofm1 = (task%(handle->blocksofm * handle->desc.R * handle->desc.S))/(handle->desc.R * handle->desc.S); + kj = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))/handle->desc.S; + ki = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))%handle->desc.S; + bf16_vnni_transpose_kernel( &LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb), + &LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb), + handle->ifmblock, handle->ofmblock, handle->ifmblock, handle->ofmblock); + } + } else { + for (task = transpose_thr_begin; task < transpose_thr_end; ++task) { + ifm1 = task/(handle->blocksofm * handle->desc.R * handle->desc.S); + ofm1 = (task%(handle->blocksofm * handle->desc.R * handle->desc.S))/(handle->desc.R * handle->desc.S); + kj = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))/handle->desc.S; + ki = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))%handle->desc.S; + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2/lpb, ifm2, ofm2%lpb, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb) = + LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, ifm2/lpb, ofm2, ifm2%lpb, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb); + } + } + } + } +} +/* wait for transpose to finish */ + +libxsmm_barrier_wait(handle->barrier, ltid); + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN( ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN( my_img_start + 1, handle->desc.N); + myIfmId = ltid % threads_per_image; + nIfmBlocks = (handle->blocksifm + threads_per_image - 1) / threads_per_image; + my_ifm_start = LIBXSMM_MIN(myIfmId * nIfmBlocks, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((myIfmId+1) * nIfmBlocks, handle->blocksifm); +} + +if ( handle->use_ifm_parallelization == 1 ) { + int spread_out = 0; + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ifmpt = (handle->blocksifm+spread_out-1)/spread_out; + int ifm_id = ltid % spread_out; + imgpt = ((handle->desc.N + handle->desc.threads - 1)/handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN( tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN( (tile_id+1) * imgpt, handle->desc.N); + my_ifm_start = LIBXSMM_MIN( ifm_id * ifmpt, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN( (ifm_id+1) * ifmpt, handle->blocksifm); + } +} + +n_blocks = (unsigned long long)handle->blocksofm_blocking * handle->desc.R * handle->desc.S; +out_ptr = (float*) &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ifwp, handle->ifmblock); + +#if 1 +if (handle->desc.R == 1 && handle->desc.S == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + /* Batch-reduce GEMM call */ + br_gemm_kernel_strd( &LIBXSMM_VLA_ACCESS(7, weight, ifm1, 0, 0, 0, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb), + &LIBXSMM_VLA_ACCESS(5, output, img, 0, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), out_ptr, &n_blocks); + /* Downconvert accumulated tiles to BF16 */ + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, scratch_ifwp, handle->ifmblock), &LIBXSMM_VLA_ACCESS( 5, del_input, img, ifm1, oj+ojj, oi, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), handle->bwd_ofw_rb * handle->ifmblock); + } + } + } + } + } + } + } +} +else { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + /* Batch-reduce GEMM call */ + br_gemm_kernel_offs( &LIBXSMM_VLA_ACCESS(7, weight, ifm1, 0, 0, 0, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb), + &LIBXSMM_VLA_ACCESS(5, output, img, 0, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), out_ptr, &n_blocks, handle->A_offsets_bwd, handle->B_offsets_bwd); + /* Downconvert accumulated tiles to BF16 */ + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, scratch_ifwp, handle->ifmblock), &LIBXSMM_VLA_ACCESS( 5, del_input, img, ifm1, oj+ojj, oi, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), handle->bwd_ofw_rb * handle->ifmblock); + } + } + } + } + } + } + } +} + +if (handle->pack_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, oj/handle->desc.u, oi/handle->desc.v, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock); + } + } + } + } + } + } +} else if (handle->spread_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } +} +#else +if (handle->loop_order == 0) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) {*/ + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float *temp_ptr = (float*)&LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (float)0; + } + temp_ptr += handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + ij_use = oj; + ii_use = oi; + oj_use = oj - (1-handle->desc.pad_h_out); + oi_use = oi - (1-handle->desc.pad_w_out); + last_kj = handle->desc.R-1; + last_ki = handle->desc.S-1; + next_kj = kj+1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + del_inp_ptr = (handle->avoid_acc_load_bwd == 1) ? &LIBXSMM_VLA_ACCESS(3, scratch_fp32, 0, 0, 0, ifwp_scratch, handle->ifmblock) + : &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (handle->avoid_acc_load_bwd == 1) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, ifwp_scratch, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + (handle->bwd_ofw_rb-1) * handle->ifmblock); + } + } else if (ofm2 == handle->blocksofm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } + } else if (oi == handle->ofw-handle->bwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + del_inp_ptr = (handle->avoid_acc_load_bwd == 1) ? &LIBXSMM_VLA_ACCESS(3, scratch_fp32, 0, 0, 0, ifwp_scratch, handle->ifmblock) + : &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (handle->avoid_acc_load_bwd == 1) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, ifwp_scratch, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + (handle->bwd_ofw_rb-1) * handle->ifmblock); + } + } else if (ofm2 == handle->blocksofm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } + } else { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + n_blocks = ind; + del_inp_ptr = (handle->avoid_acc_load_bwd == 1) ? &LIBXSMM_VLA_ACCESS(3, scratch_fp32, 0, 0, 0, ifwp_scratch, handle->ifmblock) + : &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (handle->avoid_acc_load_bwd == 1) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, ifwp_scratch, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } else if (ofm2 == handle->blocksofm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + handle->bwd_ofw_rb * handle->ifmblock); + } + } + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float *temp_ptr = (float*)&LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (float)0; + } + temp_ptr += handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + del_inp_ptr = (handle->avoid_acc_load_bwd == 1) ? &LIBXSMM_VLA_ACCESS(3, scratch_fp32, 0, 0, 0, ifwp_scratch, handle->ifmblock) + : &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (handle->avoid_acc_load_bwd == 1) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, ifwp_scratch, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + ifwp_scratch * handle->ifmblock); + } + } else if (ofm2 == handle->blocksofm && kj == handle->desc.R && ki == handle->desc.S) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + ifwp_scratch * handle->ifmblock); + } + } + } + } + } + } + } + } + } + } + } +} + +if (handle->loop_order == 1) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) { */ + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float *temp_ptr = (float*)&LIBXSMM_VLA_ACCESS( 5, del_input_fp32, img, ifm1, oj, 0, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (float)0; + } + temp_ptr += handle->ifmblock; + } + } + } + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ifm1, ofm2, kj, ki, 0, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, ofm2, oj_use + kj, oi_use + ki, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + del_inp_ptr = (handle->avoid_acc_load_bwd == 1) ? &LIBXSMM_VLA_ACCESS(3, scratch_fp32, 0, 0, 0, ifwp_scratch, handle->ifmblock) + : &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + br_gemm_kernel(A_ptrs, B_ptrs, del_inp_ptr, &n_blocks); + if (handle->avoid_acc_load_bwd == 1) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, ifwp_scratch, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + ifwp_scratch * handle->ifmblock); + } + } else if (ofm2 == handle->blocksofm && kj == handle->desc.R && ki == handle->desc.S) { + for (ojj = 0; ojj < handle->bwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, del_input_fp32, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, ij_use+ojj, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + ifwp_scratch * handle->ifmblock); + } + } + } + } + } + } + } + } + } + } +} + +if (handle->pack_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1, oj/handle->desc.u, oi/handle->desc.v, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) ; + } + } + } + } + } + } +} else if (handle->spread_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->ifmblock, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, ifm1, oj, oi, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } +} +#endif + +handle->tilerelease_kernel(NULL, NULL, NULL); +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..22a3beebacf4b7e7f61491cd5ac908a616dae23c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c @@ -0,0 +1,191 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Rajkishore Barik, Ankush Mandal, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +int imgifm1, img, ofm1, ifm1, oj, ij, oi, ii, kj, ki, ifm2, ofm2, ifm1ofm1; +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* number of tasks that could be run in parallel */ +const int work = handle->desc.N * handle->blocksifm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm; +/* compute chunk size */ +const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + +/* offset pointer in case of physical padding */ +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->blocksofm * handle->ofmblock; + +/* Weight and transpose_weight tensor declaration */ +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM) +LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK) +LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif +LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = 0; + +/* padding via stack allocated buffers */ +const int padded_w = handle->desc.W + (2 * handle->desc.pad_w); +const int padded_h = handle->desc.H + (2 * handle->desc.pad_h); +const int size_tls1 = padded_h * padded_w * handle->ifmblock; +element_input_type *const del_input_scratch_padding = (element_input_type*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) + ltid * size_tls1; +for ( ii = 0; ii < size_tls1; ++ii ) { del_input_scratch_padding[ii] = (element_input_type)0; } + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) { + weight_base = (element_filter_type*)handle->reg_filter_tr->data; +} else { + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / handle->blocksifm; + ifm1 = ifm1ofm1 % handle->blocksifm; + for (kj=0; kj < handle->desc.R; kj++) { + for (ki=0; ki < handle->desc.S; ki++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, kj, ki, ifm1, ifm2, ofm1, ofm2, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + } + } + } + } + } + weight_base = (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); +} + +{/* open new scope for additional variable declarations (C89) */ +LIBXSMM_VLA_DECL(5, element_input_type, del_input, (element_output_type*)handle->grad_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); +LIBXSMM_VLA_DECL(3, element_input_type, del_input_padded, del_input_scratch_padding, padded_w, handle->ifmblock); +LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + +for (imgifm1 = thr_begin; imgifm1 < thr_end; ++imgifm1) { + img = imgifm1 / handle->blocksifm; + ifm1 = imgifm1 % handle->blocksifm; + + /* check if we need padding, for now we do physical padding on the fly, however we can play with N parameter of the GEMM */ + /* @TODO: add variant which deals with multiple GEMMS by varying N to deal with padding */ + if ( (handle->desc.pad_h == handle->desc.pad_h_in) && (handle->desc.pad_w == handle->desc.pad_w_in) ) { + + /* reset result buffer to zero when intent is to overwrite when first block + of input channels should be convoluted */ + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) ) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, 0, 0, ifm1, 0, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock)); + /*LIBXSMM_PRAGMA_SIMD*/ + for (ij = 0; ij < handle->ifhp*handle->ifwp; ij++) { + for (ii = 0; ii < handle->ifmblock; ii++) { + temp_ptr[ii] = (element_input_type)0; + } + temp_ptr += handle->blocksifm * handle->ifmblock; + } + } + + /* run convolution */ + for (ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1) { + for ( oj = 0; oj < handle->ofh; ++oj) { + ij = oj * handle->desc.u; + oi = 0; ii = 0; + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + gemm_kernel( &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, del_input, img, ij + kj, ii + ki, ifm1, 0, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) ); + } + } + } + } + + /* zero rim in case of physical padding.... this code is extremely stupid and crappy as it requires a complicated if... */ + if (handle->desc.pad_h_in > 0 || handle->desc.pad_w_in > 0) { + for ( ij = 0; ij < handle->ifhp; ij++ ) { + for ( ii = 0; ii < handle->ifwp; ii++ ) { + if ( (ij < handle->desc.pad_h_in) || (ij >= (handle->desc.H+handle->desc.pad_h_in)) || + (ii < handle->desc.pad_w_in) || (ii >= (handle->desc.W+handle->desc.pad_w_in)) ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ij, ii, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } else { + /* reset result buffer to zero when intent is to overwrite when first block + of input channels should be convoluted */ + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) ) { + LIBXSMM_PRAGMA_SIMD + for (ij = 0; ij < size_tls1; ++ij) { + del_input_scratch_padding[ij] = (element_output_type)0; + } + } else { + for (ij = 0; ij < handle->desc.H; ij++) { + for (ii = 0; ii < handle->desc.W; ii++) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + handle->desc.pad_h, ii + handle->desc.pad_w, ifm2, padded_w, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, del_input, img, ij, ii, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + } + } + } + } + + /* run convolution */ + for (ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1) { + for ( oj = 0; oj < handle->ofh; ++oj) { + ij = oj * handle->desc.u; + oi = 0; ii = 0; + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + gemm_kernel( &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + kj, ii + ki, 0, padded_w, handle->ifmblock) ); + } + } + } + } + + /* input padding copy back */ + for (ij = 0; ij < handle->desc.H; ij++) { + for (ii = 0; ii < handle->desc.W; ii++) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ij, ii, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(3, del_input_padded, ij + handle->desc.pad_h, ii + handle->desc.pad_w, ifm2, padded_w, handle->ifmblock); + } + } + } + } +} /* end of imgifm1 loop */ + +} /* end of new scope for additional variable declarations (C89) */ + +libxsmm_barrier_wait(handle->barrier, ltid); diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..d2cfb8e6cf5fb2bb775634f8975b5efdd1a39b37 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c @@ -0,0 +1,364 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +int img, ofm1, ofm2, ifm1, ifm2, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myIfmId, nIfmBlocks, ind, /*task,*/ ifm1ofm1; +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int my_ifm_start = 0; +int my_ifm_end = handle->blocksifm; + +/* Batch reduce related variables */ +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +/* number of tasks for transpose that could be run in parallel */ +int transpose_work = handle->blocksifm * handle->blocksofm * handle->desc.R * handle->desc.S; +/* compute chunk size */ +int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; +int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; +/* offset output pointer in case of physical padding */ +const int IFW = (handle->pack_input_bwd == 1) ? handle->ofw : handle->ifwp; +const int IFH = (handle->pack_input_bwd == 1) ? handle->ofh : handle->ifhp; +element_input_type *input_ptr = (handle->pack_input_bwd == 1) ? (element_input_type*)((char*)handle->scratch + handle->bwd_packing_padding_scratch_offset) : (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->blocksifm * handle->ifmblock; +LIBXSMM_VLA_DECL(5, element_input_type, del_input, input_ptr, IFH, IFW, handle->blocksifm, handle->ifmblock); +element_output_type *const out = (element_output_type*)handle->grad_output->data; +LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + +/* Weight and transpose_weight tensor declaration */ +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM) +LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK) +LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif +LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset), handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); +/* define weight pointer which has the correct format */ +element_filter_type* weight_base = ((handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) ? (element_filter_type*)handle->reg_filter_tr->data : (element_filter_type*)((char*)handle->scratch + handle->bwd_filter_trans_scratch_offset); +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, weight_base, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* transpose filters, if requested */ +if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) == 0 ) { + /* Special case of 64x64 transpose with JITed transpose */ +#if 0 + if (handle->ifmblock == 64 && handle->ofmblock == 64) { + libxsmm_xtransfunction tr_kernel = handle->tr_kernel; + const unsigned int ld_in = 64; + const unsigned int ld_out = 64; + for (task = transpose_thr_begin; task < transpose_thr_end; ++task) { + ifm1 = task/(handle->blocksofm * handle->desc.R * handle->desc.S); + ofm1 = (task%(handle->blocksofm * handle->desc.R * handle->desc.S))/(handle->desc.R * handle->desc.S); + kj = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))/handle->desc.S; + ki = ((task%(handle->blocksofm * handle->desc.R * handle->desc.S))%(handle->desc.R * handle->desc.S))%handle->desc.S; + tr_kernel(&LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &ld_in, + &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), &ld_out); + tr_kernel(&LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 16, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &ld_in, + &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 16, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), &ld_out); + tr_kernel(&LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 32, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &ld_in, + &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 32, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), &ld_out); + tr_kernel(&LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, 48, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &ld_in, + &LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, 0, 48, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock), &ld_out); + } + } else { +#endif + /* number of tasks for transpose that could be run in parallel */ + transpose_work = handle->blocksifm * handle->blocksofm; + /* compute chunk size */ + transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / handle->blocksifm; + ifm1 = ifm1ofm1 % handle->blocksifm; + for (kj=0; kj < handle->desc.R; kj++) { + for (ki=0; ki < handle->desc.S; ki++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj , handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(6, wt, kj, ki, ifm1, ifm2, ofm1, ofm2, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + } + } + } + } + } +#if 0 + } +#endif + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N); + myIfmId = ltid % threads_per_image; + nIfmBlocks = LIBXSMM_UPDIV(handle->blocksifm, threads_per_image); + my_ifm_start = LIBXSMM_MIN(myIfmId * nIfmBlocks, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((myIfmId+1) * nIfmBlocks, handle->blocksifm); +} + +if ( handle->use_ifm_parallelization == 1 ) { + int spread_out = 0; + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N); + my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + } +} + +if (handle->loop_order == 0) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) {*/ + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, oj, 0, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (element_input_type)0; + } + temp_ptr += handle->blocksifm * handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + ij_use = oj; + ii_use = oi; + oj_use = oj - (1-handle->desc.pad_h_out); + oi_use = oi - (1-handle->desc.pad_w_out); + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, oj_use + kj, oi_use + ki + 1, ofm2, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ij_use, ii_use + 1, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), &n_blocks); + } else if (oi == handle->ofw-handle->bwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, oj_use + kj, oi_use + ki, ofm2, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ij_use, ii_use, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), &n_blocks); + } else { + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, oj_use + kj, oi_use + ki, ofm2, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ij_use, ii_use, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, oj, 0, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (element_input_type)0; + } + temp_ptr += handle->blocksifm * handle->ifmblock; + } + } + } + + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, oj_use + kj, oi_use + ki, ofm2, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ij_use, ii_use, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), &n_blocks); + } + } + } + } + } + } + } + } + } +} + +if (handle->loop_order == 1) { /* (loop_order == N_Kb_Cb_Hb_k_c_h_w) { */ + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += handle->block_bwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_bwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_bwd_oj,handle->ofh); oj += handle->bwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->bwd_ofw_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_bwd_ifm, my_ifm_end); ifm1++ ) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_bwd_ofm) { + if ( (ofmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load_bwd == 0 && ojb == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_input_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, del_input, img, oj, 0, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + temp_ptr[ifm2] = (element_input_type)0; + } + temp_ptr += handle->blocksifm * handle->ifmblock; + } + } + } + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_bwd_ofm, handle->blocksofm); ofm1 += handle->blocksofm_blocking) { + /* Prepare batch-reduce kernel arguments */ + ij_use = (handle->spread_input_bwd == 1) ? oj * handle->desc.u : oj; + ii_use = (handle->spread_input_bwd == 1) ? oi * handle->desc.v : oi; + oi_use = oi; + oj_use = oj; + ind = 0; + for (ofm2 = ofm1; ofm2 < ofm1 + handle->blocksofm_blocking; ofm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ifm1, ofm2, kj, ki, 0, 0, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img, oj_use + kj, oi_use + ki, ofm2, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, del_input, img, ij_use, ii_use, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), &n_blocks); + } + } + } + } + } + } + } + } +} + +if (handle->pack_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->blocksifm * handle->ifmblock, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, oj, oi, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) = (element_input_type)0; + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, oj, oi, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, del_input, img, oj/handle->desc.u, oi/handle->desc.v, ifm1, ifm2, IFH, IFW, handle->blocksifm,handle->ifmblock); + } + } + } + } + } + } +} else if (handle->spread_input_bwd == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, del_input_full, (element_input_type*)handle->grad_input->data + ((size_t)handle->desc.pad_h_in * handle->ifwp + handle->desc.pad_w_in) * handle->blocksifm * handle->ifmblock, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ifhp; oj++) { + for (oi = 0; oi < handle->ifwp; oi++) { + if (oi % handle->desc.v != 0 || oj % handle->desc.u != 0) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, del_input_full, img, oj, oi, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..c116cc204697dfd72404232b6c8ac12a9781e4cf --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c @@ -0,0 +1,519 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ + +int img, ofm1, ofm2 = 0, ifm1, ifm2 = 0, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myOfmId, nOfmBlocks, ind, ofm11, ki1, kj1, ojj, oii, ii, ij, spread_out = 1; +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int my_ofm_start = 0; +int my_ofm_end = handle->blocksofm; + +/* Batch reduce related variables */ +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +/* offset output pointer in case of physical output padding */ +element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +element_input_type *input_ptr = ( (handle->pack_input == 1) || (handle->fwd_padding_copy == 1) ) ? (element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp ); +const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp ); +LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N); + myOfmId = ltid % threads_per_image; + nOfmBlocks = LIBXSMM_UPDIV(handle->blocksofm, threads_per_image); + my_ofm_start = LIBXSMM_MIN(myOfmId * nOfmBlocks, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((myOfmId+1) * nOfmBlocks, handle->blocksofm); +} + +if ( handle->use_ofm_parallelization == 1 ) { + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ofmpt = LIBXSMM_UPDIV(handle->blocksofm, spread_out); + int ofm_id = ltid % spread_out; + imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N); + my_ofm_start = LIBXSMM_MIN(ofm_id * ofmpt, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((ofm_id+1) * ofmpt, handle->blocksofm); + } +} + +/* remove stride from input */ +if (handle->pack_input == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 || handle->desc.N % handle->desc.threads != 0) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +/* physical pad input */ +if (handle->fwd_padding_copy == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + /* copy the inner part */ + for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) { + for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) { + if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 || handle->desc.N % handle->desc.threads != 0 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +if (handle->use_fallback_fwd_loops == 1) { + /* number of tasks that could be run in parallel */ + const int work = handle->desc.N * handle->blocksofm * handle->ofh; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int imgofm1ofh; + + if ( handle->avoid_fmas_in_rim == 1) { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); +#if 1 + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; +#else + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm; + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm; +#endif + + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_b_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_b_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_a_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } else { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); +#if 1 + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; +#else + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm; + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm; +#endif + + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; +#if 1 + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel_a_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); +#else + LIBXSMM_UNUSED( ifm2 ); + LIBXSMM_UNUSED( kj ); + LIBXSMM_UNUSED( ki ); + n_blocks = handle->blocksifm_blocking * handle->desc.R * handle->desc.S; + if (handle->desc.R == 1 && handle->desc.S == 1) { + br_gemm_kernel_strd( &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm1, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks ); + } else { + br_gemm_kernel_offs( &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm1, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets ); + } +#endif + } + } + } + } + } + +} else { + if (handle->loop_order == 0) { + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_b_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_b_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_a_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; +#if 1 + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel_a_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); +#else + LIBXSMM_UNUSED( ifm2 ); + LIBXSMM_UNUSED( kj ); + LIBXSMM_UNUSED( ki ); + n_blocks = handle->blocksifm_blocking * handle->desc.R * handle->desc.S; + if (handle->desc.R == 1 && handle->desc.S == 1) { + br_gemm_kernel_strd( &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm1, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks ); + } else { + br_gemm_kernel_offs( &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm1, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets ); + } +#endif + } + } + } + } + } + } + } + } + } + } + + if (handle->loop_order == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + if (((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (ojj = 0; ojj < handle->ofh; ++ojj) { + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, ojj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oii = 0; oii < handle->ofw; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + } + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; +#if 1 + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel_a_addr(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); +#else + LIBXSMM_UNUSED( ifm2 ); + LIBXSMM_UNUSED( kj ); + LIBXSMM_UNUSED( ki ); + n_blocks = handle->blocksifm_blocking * handle->desc.R * handle->desc.S; + if (handle->desc.R == 1 && handle->desc.S == 1) { + br_gemm_kernel_strd( &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm1, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks ); + } else { + br_gemm_kernel_offs( &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm1, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets ); + } +#endif + } + } + } + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..170e3afd82c108d7331fd5c477cbc6d4d9b80f51 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c @@ -0,0 +1,609 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ + +int img, ofm1, ofm2 = 0, ifm1, ifm2 = 0, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myOfmId, nOfmBlocks, ind, ofm11, ki1, kj1, ojj, oii, spread_out = 1, ij = 0, ii = 0; +int last_ki, last_kj, next_kj; +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int my_ofm_start = 0; +int my_ofm_end = handle->blocksofm; +int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +/* Batch reduce related variables */ +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; +/* JITed eltwise function */ +libxsmm_meltwfunction_unary cvt_kernel = handle->fwd_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param cvt_params; + +/* offset output pointer in case of physical output padding */ +element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +float* out_fp32 = (float*)((char*)handle->scratch + handle->fwd_lp_output_full_scratch_offset) + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +float* out_scratch = (float*)((char*)handle->scratch + handle->fwd_lp_output_block_scratch_offset) + ((size_t) ltid * handle->fwd_ofw_rb * handle->fwd_ofh_rb * handle->ofmblock); +float* out_ptr; +LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(5, float, output_fp32, out_fp32, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(3, float, scratch_fp32, out_scratch, handle->fwd_ofw_rb, handle->ofmblock); +element_input_type *input_ptr = ((handle->pack_input == 1) || (handle->fwd_padding_copy == 1)) ?(element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp ); +const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp ); +LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +LIBXSMM_VLA_DECL(7, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + +libxsmm_barrier_init(handle->barrier, ltid); + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N); + myOfmId = ltid % threads_per_image; + nOfmBlocks = LIBXSMM_UPDIV(handle->blocksofm, threads_per_image); + my_ofm_start = LIBXSMM_MIN(myOfmId * nOfmBlocks, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((myOfmId+1) * nOfmBlocks, handle->blocksofm); +} + +if ( handle->use_ofm_parallelization == 1 ) { + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ofmpt = LIBXSMM_UPDIV(handle->blocksofm, spread_out); + int ofm_id = ltid % spread_out; + imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N); + my_ofm_start = LIBXSMM_MIN(ofm_id * ofmpt, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((ofm_id+1) * ofmpt, handle->blocksofm); + } +} + +if (handle->pack_input == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +/* physical pad input */ +if (handle->fwd_padding_copy == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + /* copy the inner part */ + for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) { + for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) { + if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 || handle->desc.N % handle->desc.threads != 0 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +if (handle->use_fallback_fwd_loops == 1) { + /* number of tasks that could be run in parallel */ + const int work = handle->desc.N * handle->blocksofm * handle->ofh; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int imgofm1ofh; + + if ( handle->avoid_fmas_in_rim == 1) { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + last_kj = handle->desc.R-1; + last_ki = handle->desc.S-1; + next_kj = kj+1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + + if (handle->avoid_acc_load == 1) { + br_gemm_kernel2_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + + if (handle->avoid_acc_load == 1) { + br_gemm_kernel2_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + if (handle->avoid_acc_load == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } + } + } + } + } + } + } + } else { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; + + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + + if (handle->avoid_acc_load == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, handle->fwd_ofw_rb, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && kj == handle->desc.R && ki == handle->desc.S) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } + } + } + } + } +} else { + if (handle->loop_order == 0) { + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + last_ki = (handle->shuffle_filter_accesses == 1) ? (handle->desc.S-1+ltid)%handle->desc.S : handle->desc.S-1; + last_kj = (handle->shuffle_filter_accesses == 1) ? (handle->desc.R-1+ltid)%handle->desc.R : handle->desc.R-1; + next_kj = (handle->shuffle_filter_accesses == 1) ? (kj1+1+ltid)%handle->desc.R : kj1+1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + + if (handle->avoid_acc_load == 1) { + br_gemm_kernel2_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + if (handle->avoid_acc_load == 1) { + br_gemm_kernel2_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + if (handle->avoid_acc_load == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + kj1 = 0; + ki1 = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + + if (handle->avoid_acc_load == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (kj1 == handle->desc.R && ki1 == handle->desc.S && ifm2 == handle->blocksifm) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } + } + } + } + } + } + } + } + } + } + + if (handle->loop_order == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + if (((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (ojj = 0; ojj < handle->ofh; ++ojj) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, ojj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oii = 0; oii < handle->ofw; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + } + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + + if (handle->avoid_acc_load == 1) { + br_gemm_kernel_bf16bf16(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + out_ptr = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (kj == handle->desc.R && ki == handle->desc.S && ifm2 == handle->blocksifm) { + cvt_params.in.primary = &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_params.out.primary = &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + cvt_kernel(&cvt_params); + } + } + } + } + } + } + } + } + } + } + } + +#if 0 + /* In case we used intermediate fp32 buffer, now downconvert the result to the actual bf16 output */ + if (handle->avoid_acc_load == 0) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->ofw * handle->ofmblock); + } + } + } + } +#endif + +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..3db6596c4b82cb1d41651c3791152f629adb2209 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c @@ -0,0 +1,732 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +int img, ofm1, ifm1, ifm2, /*ofm2, ifm1, ifm2,*/ oj, oi, ij, ii, /*kj, ki, oi_use, oj_use, */ii_use, ij_use, ofmb,/* ifmb,*/ ojb, myOfmId, nOfmBlocks, /*ind, ofm11, ki1, kj1,*/ ojj, /*oii,*/ spread_out = 1; +/*int last_ki, last_kj, next_kj;*/ +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = (handle->desc.N + handle->desc.threads - 1)/handle->desc.threads; +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN( ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN( (ltid+1) * imgpt, handle->desc.N); +int my_ofm_start = 0; +int my_ofm_end = handle->blocksofm; +int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +/* Batch reduce related variables */ +/*const element_filter_type *A_ptrs[1024];*/ +/*const element_input_type *B_ptrs[1024];*/ +unsigned long long n_blocks; + +/* offset output pointer in case of physical output padding */ +element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +/*float* out_fp32 = (float*)handle->scratch6 + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;*/ +float* out_ptr; +LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +/*LIBXSMM_VLA_DECL(5, float, output_fp32, out_fp32, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);*/ +int scratch_ofwp = (handle->fwd_gemm_pixels == (handle->fwd_ofw_rb * handle->fwd_ofh_rb)) ? handle->fwd_ofw_rb : ((handle->fwd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp); +/*float scratch_stack_fp32[8*16*16];*/ +float *out_scratch = (float*)((char*)handle->scratch + handle->fwd_lp_output_full_scratch_offset) + ltid * handle->fwd_gemm_pixels * handle->ofmblock; +LIBXSMM_VLA_DECL(3, float, scratch_fp32, out_scratch, scratch_ofwp, handle->ofmblock); +element_input_type *input_ptr = ((handle->pack_input == 1) || (handle->fwd_padding_copy == 1)) ?(element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp ); +const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp ); +LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +LIBXSMM_VLA_DECL(7, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN( ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN( my_img_start + 1, handle->desc.N); + myOfmId = ltid % threads_per_image; + nOfmBlocks = (handle->blocksofm + threads_per_image - 1) / threads_per_image; + my_ofm_start = LIBXSMM_MIN(myOfmId * nOfmBlocks, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((myOfmId+1) * nOfmBlocks, handle->blocksofm); +} + +if ( handle->use_ofm_parallelization == 1 ) { + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ofmpt = (handle->blocksofm+spread_out-1)/spread_out; + int ofm_id = ltid % spread_out; + imgpt = ((handle->desc.N + handle->desc.threads - 1)/handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN( tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN( (tile_id+1) * imgpt, handle->desc.N); + my_ofm_start = LIBXSMM_MIN( ofm_id * ofmpt, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN( (ofm_id+1) * ofmpt, handle->blocksofm); + } +} + +n_blocks = (unsigned long long)handle->blocksifm_blocking * handle->desc.R * handle->desc.S; +out_ptr = (float*) &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock); + +libxsmm_barrier_init(handle->barrier, ltid); + +if (handle->pack_input == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +/* physical pad input */ +if (handle->fwd_padding_copy == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + /* copy the inner part */ + for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) { + for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) { + if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 || handle->desc.N % handle->desc.threads != 0 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +/* Execute the tileconfig kernel */ +tile_config_kernel(NULL, NULL, NULL); + +#if 1 +if (handle->desc.R == 1 && handle->desc.S == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u; + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v; + /* Batch-reduce GEMM call */ + br_gemm_kernel_strd( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, 0, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, 0, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } + } + } + } + } + } +} +/* @TODO this needs a reasonable fix */ +else if ( handle->fwd_ofw_rb*handle->fwd_ofh_rb == handle->fwd_gemm_pixels ) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Batch-reduce GEMM call */ + br_gemm_kernel_offs_a( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, 0, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, 0, oj*handle->desc.u, oi*handle->desc.v, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &n_blocks, handle->A_offsets, handle->B_offsets); + } + } + } + } + } + } +} else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Batch-reduce GEMM call */ + br_gemm_kernel_offs_b( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, 0, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, 0, oj, oi, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), out_ptr, &n_blocks, handle->A_offsets, handle->B_offsets); + /* Downconvert accumulated tiles to BF16 */ + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, scratch_ofwp, handle->ofmblock), &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj+ojj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), handle->fwd_ofw_rb * handle->ofmblock); + } + } + } + } + } + } + } +} +#else +if (handle->pack_input == 1) { + int ifmpt = (handle->blocksifm+spread_out-1)/spread_out; + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN( ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN( (ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +if (handle->use_fallback_fwd_loops == 1) { + /* number of tasks that could be run in parallel */ + const int work = handle->desc.N * handle->blocksofm * handle->ofh; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int imgofm1ofh; + + if ( handle->avoid_fmas_in_rim == 1) { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + last_kj = handle->desc.R-1; + last_ki = handle->desc.S-1; + next_kj = kj+1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use+1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + (handle->fwd_ofw_rb-1) * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + (handle->fwd_ofw_rb-1) * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } + } + } + } + } + } + } + } else { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; + + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && kj == handle->desc.R && ki == handle->desc.S) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } + } + } + } + } +} else { + if (handle->loop_order == 0) { + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + last_ki = (handle->shuffle_filter_accesses == 1) ? (handle->desc.S-1+ltid)%handle->desc.S : handle->desc.S-1; + last_kj = (handle->shuffle_filter_accesses == 1) ? (handle->desc.R-1+ltid)%handle->desc.R : handle->desc.R-1; + next_kj = (handle->shuffle_filter_accesses == 1) ? (kj1+1+ltid)%handle->desc.R : kj1+1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use+1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + (handle->fwd_ofw_rb-1) * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel2(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + (handle->fwd_ofw_rb-1) * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } else if (ifm2 == handle->blocksifm && + ((kj == last_kj && ki == last_ki) || + (next_kj == 0 && next_kj == last_kj && oj == 0) || + (next_kj == handle->desc.R-1 && next_kj == last_kj && oj == handle->ofh-1))) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + kj1 = 0; + ki1 = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } else if (kj1 == handle->desc.R && ki1 == handle->desc.S && ifm2 == handle->blocksifm) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } + } + } + } + } + } + } + } + } + } + + if (handle->loop_order == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + if (((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (ojj = 0; ojj < handle->ofh; ++ojj) { + float* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, ojj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oii = 0; oii < handle->ofw; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (float)0; + } + temp_ptr += handle->ofmblock; + } + } + } + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + kj = 0; + ki = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm2, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + out_ptr = (handle->avoid_acc_load == 1) ? &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, 0, 0, 0, scratch_ofwp, handle->ofmblock) : &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + br_gemm_kernel(A_ptrs, B_ptrs, out_ptr, &n_blocks); + + if (handle->avoid_acc_load == 1) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 3, scratch_fp32, ojj, 0, 0, handle->fwd_ofw_rb, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } else if (kj == handle->desc.R && ki == handle->desc.S && ifm2 == handle->blocksifm) { + for (ojj = 0; ojj < handle->fwd_ofh_rb; ojj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS(5, output_fp32, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj_use+ojj, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->fwd_ofw_rb * handle->ofmblock); + } + } + } + } + } + } + } + } + } + } + } + +#if 0 + /* In case we used intermediate fp32 buffer, now downconvert the result to the actual bf16 output */ + if (handle->avoid_acc_load == 0) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16( &LIBXSMM_VLA_ACCESS( 5, output_fp32, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + handle->ofw * handle->ofmblock); + } + } + } + } +#endif + +} +#endif + +handle->tilerelease_kernel(NULL, NULL, NULL); +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..7b654f25b335ed27cfa431f6ac7fb5c9e1639c27 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c @@ -0,0 +1,170 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ + +int img, ofm1, ofm2, ifm1, ifm2, oj, oi, kj, ki, ii_use, ij_use, oii, spread_out = 1; +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* number of tasks that could be run in parallel */ +const int w_tasks = handle->ofw/handle->fwd_ofw_rb; +const int work = handle->desc.N * handle->blocksofm * handle->ofh * w_tasks; +const int work_KHW = handle->blocksofm * handle->ofh * w_tasks; +const int work_HW = handle->ofh * w_tasks; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; +int imgofm1ofhofw; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +/* Batch reduce related variables */ +unsigned long long n_blocks; + +/* offset output pointer in case of physical output padding */ +element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +element_input_type *input_ptr = (handle->pack_input == 1) ?(element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp; +const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp; +LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock); +LIBXSMM_VLA_DECL(7, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + +libxsmm_barrier_init(handle->barrier, ltid); + +if (handle->pack_input == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +if (handle->avoid_fmas_in_rim == 1) { + n_blocks = handle->blocksifm_blocking; + for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) { + img = imgofm1ofhofw / work_KHW; + ofm1 = (imgofm1ofhofw % work_KHW)/work_HW; + oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks; + oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb; + ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v - (1-handle->desc.pad_w_in); + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + br_gemm_kernel_strided2( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use+kj, ii_use+ki+1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi+1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + br_gemm_kernel_strided2( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use+kj, ii_use+ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } else { + br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use+kj, ii_use+ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } + } + } + } + } +} else { + /* Strided based BRGEMM */ + n_blocks = (unsigned long long)handle->blocksifm_blocking * handle->desc.R * handle->desc.S; + if (handle->desc.R == 1 && handle->desc.S == 1) { + for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) { + img = imgofm1ofhofw / work_KHW; + ofm1 = (imgofm1ofhofw % work_KHW)/work_HW; + oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks; + oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb; + ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u; + ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v; + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) { + br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks); + } + } + } else { /* Offset based BRGEMM */ + for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) { + img = imgofm1ofhofw / work_KHW; + ofm1 = (imgofm1ofhofw % work_KHW)/work_HW; + oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks; + oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb; + ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u; + ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v; + if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock)); + for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->ofmblock; + } + } + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) { + br_gemm_kernel_offset( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets); + } + } + } +} +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..961355476fa9adbb839068cb9e5a0aed552c4ef7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c @@ -0,0 +1,61 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ +const int ifmblock_lp = handle->ifmblock/handle->fm_lp_block; +int imgofm1ofhofw, img, ofm1, oj, oi, ii, ij; +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int w_tasks = handle->ofw/handle->fwd_ofw_rb; +const int work = handle->desc.N * handle->blocksofm * handle->ofh * w_tasks; +const int work_KHW = handle->blocksofm * handle->ofh * w_tasks; +const int work_HW = handle->ofh * w_tasks; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; +/* Batch reduce related variables */ +unsigned long long n_blocks = (unsigned long long)handle->blocksifm_blocking * handle->desc.R * handle->desc.S; +/* Calculate scaling factor here for output... */ +float _scf = libxsmm_sexp2_i8i(-(handle->reg_filter->scf + handle->reg_input->scf - handle->reg_output->scf)); +/* offset output pointer in case of physical output padding */ +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); +LIBXSMM_VLA_DECL(7, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block); + +libxsmm_barrier_init(handle->barrier, ltid); +if (handle->desc.R == 1 && handle->desc.S == 1) { /* Strided based BRGEMM */ + for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) { + img = imgofm1ofhofw / work_KHW; + ofm1 = (imgofm1ofhofw % work_KHW)/work_HW; + oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks; + oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb; + ij = oj * handle->desc.u; + ii = oi * handle->desc.v; + br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, 0, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, 0, ij, ii, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, &_scf); + } +} else { /* Offset based BRGEMM */ + for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) { + img = imgofm1ofhofw / work_KHW; + ofm1 = (imgofm1ofhofw % work_KHW)/work_HW; + oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks; + oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb; + ij = oj * handle->desc.u; + ii = oi * handle->desc.v; + br_gemm_kernel_offset( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, 0, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block), + &LIBXSMM_VLA_ACCESS(5, input, img, 0, ij, ii, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(5 , output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets, &_scf); + } +} +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..04232958ee358dce596252981bd72c9719117eed --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c @@ -0,0 +1,522 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.) +******************************************************************************/ + +int img, ofm1, ofm2 = 0, ifm1, ifm2 = 0, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myOfmId, nOfmBlocks, ind, ofm11, ki1, kj1, ojj, oii, ii, ij, spread_out = 1; +/* computing first logical thread */ +const int ltid = tid - start_thread; +int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); +int threads_per_image = handle->desc.threads / handle->desc.N; +int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); +int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); +int my_ofm_start = 0; +int my_ofm_end = handle->blocksofm; + +/* Batch reduce related variables */ +const element_filter_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +/* offset output pointer in case of physical output padding */ +element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->blocksofm * handle->ofmblock; +LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); +element_input_type *input_ptr = ( (handle->pack_input == 1) || (handle->fwd_padding_copy == 1) ) ?(element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp ); +const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp ); +LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, IFH, IFW, handle->blocksifm, handle->ifmblock); +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK +LIBXSMM_VLA_DECL(6, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( imgpt <= 1 ) { + my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N); + my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N); + myOfmId = ltid % threads_per_image; + nOfmBlocks = LIBXSMM_UPDIV(handle->blocksofm, threads_per_image); + my_ofm_start = LIBXSMM_MIN(myOfmId * nOfmBlocks, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((myOfmId+1) * nOfmBlocks, handle->blocksofm); +} + +if ( handle->use_ofm_parallelization == 1 ) { + if ( handle->desc.N % 8 == 0) { + spread_out = 8; + } else if ( handle->desc.N % 4 == 0) { + spread_out = 4; + } else if (handle->desc.N % 2 == 0) { + spread_out = 2; + } else if (handle->desc.N % 3 == 0) { + spread_out = 3; + } else { + spread_out = 1; + } + if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) { + int tile_id = ltid / spread_out; + int ofmpt = LIBXSMM_UPDIV(handle->blocksofm, spread_out); + int ofm_id = ltid % spread_out; + imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out; + my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N); + my_ofm_start = LIBXSMM_MIN(ofm_id * ofmpt, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((ofm_id+1) * ofmpt, handle->blocksofm); + } +} + +/* remove stride from input */ +if (handle->pack_input == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + /* @TODO think about packed format */ + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, oj, oi, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ij_use, ii_use, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +/* physical pad input */ +if (handle->fwd_padding_copy == 1) { + int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out); + int ifm_id = ltid % spread_out; + int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm); + int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm); + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + /* copy the inner part */ + for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) { + for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) { + if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, input_src, img, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } + if ( handle->use_ofm_parallelization == 1 ) { + libxsmm_barrier_wait(handle->barrier, ltid); + } +} + +if (handle->use_fallback_fwd_loops == 1) { + /* number of tasks that could be run in parallel */ + const int work = handle->desc.N * handle->blocksofm * handle->ofh; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int imgofm1ofh; + + if ( handle->avoid_fmas_in_rim == 1) { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); +#if 1 + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; +#else + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm; + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm; +#endif + + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->blocksofm*handle->ofmblock; + } + } + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki + 1, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use + 1, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } else { + for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) { + img = imgofm1ofh / (handle->blocksofm*handle->ofh); +#if 1 + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh; + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh; +#else + oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm; + ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm; +#endif + + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) { + /* set output feature map to zero */ + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->blocksofm*handle->ofmblock; + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } + } + } + } + } + +} else { + if (handle->loop_order == 0) { + if ( handle->avoid_fmas_in_rim == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->blocksofm*handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in); + ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in); + } + oi_use = oi; + oj_use = oj; + + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; + + if (kj == 0 && oj == 0) { + /* Do no FLOPS */ + } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) { + /* Do no FLOPS */ + } else if ( oi == 0 && ki == 0 ) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki + 1, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use + 1, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } else { + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) { + ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11; + if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) { + /* set output feature map to zero */ + for (oj = 0; oj < handle->ofh; ++oj) { + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock)); + for (oi = 0; oi < handle->ofw; ++oi) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->blocksofm * handle->ofmblock; + } + } + } + + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj1 = 0; kj1 < handle->desc.R; kj1++) { + for (ki1 = 0; ki1 < handle->desc.S; ki1++) { + ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1; + kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1; +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + + if (handle->loop_order == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) { + if (((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && oj == 0 && oi == 0) { + /* set output feature map to zero */ + for (ojj = 0; ojj < handle->ofh; ++ojj) { + element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ojj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock)); + for (oii = 0; oii < handle->ofw; ++oii) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) { + temp_ptr[ofm2] = (element_output_type)0; + } + temp_ptr += handle->blocksofm * handle->ofmblock; + } + } + } + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) { + /* Prepare batch-reduce kernel arguments */ + if (handle->pack_input == 1) { + ij_use = oj; + ii_use = oi; + } else { + ij_use = oj * handle->desc.u; + ii_use = oi * handle->desc.v; + } + oi_use = oi; + oj_use = oj; + ind = 0; + for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) { + for (kj = 0; kj < handle->desc.R; kj++) { + for (ki = 0; ki < handle->desc.S; ki++) { +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock); + ind++; + } + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..356d4138f68eff2290e9c143c58ad9f134f460ca --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c @@ -0,0 +1,577 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +int img, my_img_start, my_img_end, ofmb, ifmb, ojb, ofm1, ifm1, ifm2 = 0, ofm2 = 0, oj, oi, ii, ij, kj, ki, ind, j_br, img_br, img_block_size = 1, my_ofm_start, my_ofm_end, my_ifm_start, my_ifm_end, block_ofm, block_ifm; +/* computing first logical thread */ +const int ltid = tid - start_thread; +libxsmm_blasint LDA = handle->ofmblock; +libxsmm_blasint LDB = (handle->upd_pack_input == 1) ? handle->ifmblock : handle->desc.v * handle->ifmblock; +libxsmm_blasint LDC = handle->ofmblock; +int l_flags = LIBXSMM_GEMM_FLAGS('N', 'T'); +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +LIBXSMM_VLA_DECL(5, const element_output_type, output, (const element_output_type*)out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +const int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : handle->ifwp; +const int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : handle->ifhp; +element_input_type *input_ptr_to_use = (handle->upd_padding_copy == 1) ? (element_input_type*) ((char*)handle->scratch + handle->upd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type*) input_ptr_to_use, handle->blocksifm, IFHP, IFWP, handle->ifmblock); +LIBXSMM_VLA_DECL(6, element_filter_type, weight_global, (element_filter_type*)handle->grad_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +element_filter_type *weight_ptr = (handle->weight_copies == 1) ? (element_filter_type*)handle->grad_filter->data : (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; +LIBXSMM_VLA_DECL(6, element_filter_type, weight_private, (element_filter_type*)weight_ptr, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +int prefetch_mode = (handle->desc.u == 2 || (handle->desc.R == 3 && handle->ofw == 7) ) ? libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE) : libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BL1); + +/* Batch reduce related variables */ +const element_output_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +int brgemm_pf_oob = 0; +const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); +if ( 0 == env_brgemm_pf_oob ) { +} else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); +} +if (brgemm_pf_oob > 0) { + prefetch_mode = prefetch_mode | libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); +} + +libxsmm_barrier_init(handle->barrier, ltid); + +/* physical pad input */ +if (handle->upd_padding_copy == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); + + my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + /* copy the inner part */ + for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) { + for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) { + if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFHP, IFWP, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFHP, IFWP, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, ltid); +} + + +if (handle->upd_use_batchreduce == 0 && handle->upd_linearized_tasklist == 0) { + /* Parallelize over minibatch */ + const int img_work = handle->desc.N; + const int img_chunksize = (img_work % handle->desc.threads == 0) ? (img_work / handle->desc.threads) : (img_work / handle->desc.threads) + 1; + const float beta = ((img_chunksize == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb * handle->upd_ofh_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + my_img_start = (ltid * img_chunksize < img_work) ? (ltid * img_chunksize) : img_work; + my_img_end = ((ltid + 1) * img_chunksize < img_work) ? ((ltid + 1) * img_chunksize) : img_work; + + if (!((img_chunksize == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw))) { + memset(weight_ptr, 0, handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S * sizeof(element_filter_type)); + } + + if (handle->upd_loop_order == 0) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_private, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) ); + } + } + } + } + } + } + } + } + } + } + } + if (handle->upd_loop_order == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_private, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) ); + } + } + } + } + } + } + } + } + } + } + } +} else { + if (handle->upd_linearized_tasklist == 1) { + /* Amount of work when using linearized view of tasks */ + const int work = handle->desc.R * handle->desc.S * handle->blocksofm * handle->blocksifm; + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : (work / handle->desc.threads) + 1; + const int work_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int work_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int work_item; + int Cb = handle->blocksifm; +#if 0 + int Kb = handle->blocksofm; +#endif + int R = handle->desc.R; + int S = handle->desc.S; + + if (handle->upd_avoid_rim_fmas == 0) { + const int IFH = (handle->upd_pack_input == 1) ? handle->ifhp/handle->desc.u : IFHP; + const int IFW = (handle->upd_pack_input == 1) ? handle->ifwp/handle->desc.v : IFWP; + element_input_type *input_ptr_base = (handle->upd_pack_input == 1) ? (element_input_type*) ((char*)handle->scratch + handle->upd_packing_padding_scratch_offset) : (element_input_type*)input_ptr_to_use; + LIBXSMM_VLA_DECL(5, element_input_type, input_use, (element_input_type*)input_ptr_base, handle->blocksifm, IFH, IFW, handle->ifmblock); + const float beta = ((handle->desc.N == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb * handle->upd_ofh_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + /* If requested, pack input to avoid strided accesses */ + if (handle->upd_pack_input == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + const int img_chunk = (handle->desc.N % handle->desc.threads == 0) ? handle->desc.N/handle->desc.threads : (handle->desc.N/handle->desc.threads) + 1; + const int img_copy_start = LIBXSMM_MIN(ltid*img_chunk, handle->desc.N); + const int img_copy_end = LIBXSMM_MIN((ltid+1)*img_chunk, handle->desc.N); + + for (img = img_copy_start; img < img_copy_end; img++) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij = oj * handle->desc.u; + ii = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input_use, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, ltid); + } + + /* Initialize weights to zero */ + if (!((handle->desc.N == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw))) { + for (work_item = work_begin; work_item < work_end; work_item++) { + ofm1 = work_item/(Cb*R*S); + ifm1 = (work_item%(Cb*R*S))/(R*S); + kj = ((work_item%(Cb*R*S))%(R*S))/S; + ki = ((work_item%(Cb*R*S))%(R*S))%S; + + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) = (element_filter_type)0; + } + } + } + } + + for (img = 0; img < handle->desc.N; img++) { + for (work_item = work_begin; work_item < work_end; work_item++) { + ofm1 = work_item/(Cb*R*S); + ifm1 = (work_item%(Cb*R*S))/(R*S); + kj = ((work_item%(Cb*R*S))%(R*S))/S; + ki = ((work_item%(Cb*R*S))%(R*S))%S; + oi = 0; + ii = ki; + for (oj = 0; oj < handle->ofh; oj += handle->upd_ofh_rb) { + ij = oj * handle->desc.u + kj; + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input_use, img, ifm1, ij, ii, 0, handle->blocksifm, IFH, IFW, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) ); + } + } + } + } else { + const float beta = ((handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb-1, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + for (work_item = work_begin; work_item < work_end; work_item++) { + ofm1 = work_item/(Cb*R*S); + ifm1 = (work_item%(Cb*R*S))/(R*S); + kj = ((work_item%(Cb*R*S))%(R*S))/S; + ki = ((work_item%(Cb*R*S))%(R*S))%S; + oi = 0; + oj = 0; + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + img = 0; + img_block_size = handle->desc.N; + + if (kj == 0) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 1; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } else if (ki == 0) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii + 1, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } else { + if (kj == handle->desc.R-1) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb-1; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } else { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } + } + } + } + } else { + /* Here we are using batch-reduce kernel and hybrid minibatch/FM parallelization */ + /* FIXME: Hardcoed logic for N=27 */ + int group_size = (handle->desc.threads == 27 && handle->desc.N == 27 && handle->ofw == 14 && handle->desc.R == 1 && handle->desc.u == 1 && ltid >= 24) ? 3 : LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies); + int tile_id = ltid / LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies); + int tiles = handle->weight_copies; + int img_per_tile = LIBXSMM_UPDIV(handle->desc.N, tiles); + int my_in_tile_id = ltid % group_size; + int ifms_per_thread = LIBXSMM_UPDIV(handle->blocksifm, group_size); + int ofms_per_thread = LIBXSMM_UPDIV(handle->blocksofm, group_size); + int my_R_start = 0; + int my_R_end = handle->desc.R; + const float beta = ((handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + const float beta_flat = 0.0; + gemm_br_function br_gemm_kernel_flat = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb, &LDA, &LDB, &LDC, NULL, &beta_flat, &l_flags, &prefetch_mode); + element_filter_type *weight_ptr_group = (handle->weight_copies > 1) ? (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; + LIBXSMM_VLA_DECL(6, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + my_img_start = LIBXSMM_MIN(tile_id * img_per_tile, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * img_per_tile, handle->desc.N); + my_ifm_start = LIBXSMM_MIN(my_in_tile_id * ifms_per_thread, handle->blocksifm ); + my_ifm_end = LIBXSMM_MIN((my_in_tile_id+1) * ifms_per_thread, handle->blocksifm ); + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + /* FIXME: Hardcoed logic for N=27 */ + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.C == 256 && handle->desc.K == 1024 && handle->ofh == 14 && handle->desc.u == 1) { + my_ofm_start = LIBXSMM_MIN(my_in_tile_id * ofms_per_thread, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((my_in_tile_id+1) * ofms_per_thread, handle->blocksofm); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + } + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 14) { + int r_per_tile = LIBXSMM_UPDIV(handle->desc.R, group_size); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + my_R_start = LIBXSMM_MIN(my_in_tile_id * r_per_tile, handle->desc.R); + my_R_end = LIBXSMM_MIN((my_in_tile_id+1) * r_per_tile, handle->desc.R); + } + if (handle->desc.threads == 92 && handle->desc.N == 92 && handle->desc.C == 512 && handle->desc.K == 512 && handle->ofh == 7 && handle->desc.u == 1 && handle->desc.R == 3) { + my_ofm_start = LIBXSMM_MIN(my_in_tile_id * ofms_per_thread, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((my_in_tile_id+1) * ofms_per_thread, handle->blocksofm); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + } + block_ofm = my_ofm_end-my_ofm_start+1; + block_ifm = my_ifm_end-my_ifm_start+1; + img_block_size = my_img_end - my_img_start; + + if (handle->desc.N != handle->desc.threads) { + /* Use "flat" parallelism + reduction */ + const int work = handle->desc.R * handle->desc.S * handle->blocksofm * handle->blocksifm * handle->desc.N; + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : (work / handle->desc.threads) + 1; + const int work_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int work_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int work_item; + int Cb = handle->blocksifm; + int Kb = handle->blocksofm; + int R = handle->desc.R; + int S = handle->desc.S; + const int IFH = (handle->upd_pack_input == 1) ? handle->ifhp/handle->desc.u : IFHP; + const int IFW = (handle->upd_pack_input == 1) ? handle->ifwp/handle->desc.v : IFWP; + element_input_type *input_ptr_base = (handle->upd_pack_input == 1) ? (element_input_type*) ((char*)handle->scratch + handle->upd_packing_padding_scratch_offset) : (element_input_type*)input_ptr_to_use; + LIBXSMM_VLA_DECL(5, element_input_type, input_use, (element_input_type*)input_ptr_base, handle->blocksifm, IFH, IFW, handle->ifmblock); + + /* If requested, pack input to avoid strided accesses */ + if (handle->upd_pack_input == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + const int img_chunk = (handle->desc.N % handle->desc.threads == 0) ? handle->desc.N/handle->desc.threads : (handle->desc.N/handle->desc.threads) + 1; + const int img_copy_start = LIBXSMM_MIN(ltid*img_chunk, handle->desc.N); + const int img_copy_end = LIBXSMM_MIN((ltid+1)*img_chunk, handle->desc.N); + + for (img = img_copy_start; img < img_copy_end; img++) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij = oj * handle->desc.u; + ii = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input_use, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, ltid); + } + + /* Initialize weights to zero */ + if (handle->upd_ofw_rb != handle->ofw) { + for (work_item = work_begin; work_item < work_end; work_item++) { + img = work_item/(Cb*Kb*R*S); + ofm1 = (work_item%(Cb*Kb*R*S))/(Cb*R*S); + ifm1 = ((work_item%(Cb*Kb*R*S))%(Cb*R*S))/(R*S); + kj = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))/S; + ki = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))%S; + { + element_filter_type *weight_ptr_current = (handle->weight_copies > 1) ? (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset) + img * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; + LIBXSMM_VLA_DECL(6, element_filter_type, weight_current, (element_filter_type*)weight_ptr_current, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(6, weight_current, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) = (element_filter_type)0; + } + } + } + } + } + + for (work_item = work_begin; work_item < work_end; work_item++) { + img = work_item/(Cb*Kb*R*S); + ofm1 = (work_item%(Cb*Kb*R*S))/(Cb*R*S); + ifm1 = ((work_item%(Cb*Kb*R*S))%(Cb*R*S))/(R*S); + kj = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))/S; + ki = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))%S; + ii = 0 + ki; + ij = 0 + kj; + oj = 0; + oi = 0; + { + element_filter_type *weight_ptr_current = (handle->weight_copies > 1) ? (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset) + img * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; + LIBXSMM_VLA_DECL(6, element_filter_type, weight_current, (element_filter_type*)weight_ptr_current, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + ind = 0; + for (j_br = 0; j_br < handle->ofh; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img , ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input_use, img, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFH, IFW, handle->ifmblock); + ind++; + } + n_blocks = ind; + br_gemm_kernel_flat(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_current, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } + } + } else { + /* May need to initialized private weights to zero */ + if (!((handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw))) { + for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++ ) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++ ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(6, weight_private_group, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) = (element_filter_type)0; + } + } + } + } + } + } + } + + if (handle->upd_loop_order == 0) { + for (img = my_img_start; img < my_img_end; img += img_block_size) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_private_group, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img += img_block_size) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, ofm1, oj + j_br, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ifm1, ij + j_br * handle->desc.u, ii, 0, handle->blocksifm, IFHP, IFWP, handle->ifmblock); + ind++; + } + } + n_blocks = ind; + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_private_group, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); + } + } + } + } + } + } + } + } + } + } + } + } + } +} + +if (handle->weight_copies > 1) { + /* reduce work-related variables */ + const int fm_blocking = (handle->ofmblock % 16 == 0) ? 16 : handle->ofmblock; + const int reduce_work = handle->blocksofm * handle->blocksifm * handle->desc.R * handle->desc.S * (handle->ofmblock/fm_blocking) * handle->ifmblock; + const int reduce_chunksize = (reduce_work % handle->desc.threads == 0) ? (reduce_work / handle->desc.threads) : (reduce_work / handle->desc.threads) + 1; + const int reduce_thr_begin = (ltid * reduce_chunksize < reduce_work) ? (ltid * reduce_chunksize) : reduce_work; + const int reduce_thr_end = ((ltid + 1) * reduce_chunksize < reduce_work) ? ((ltid + 1) * reduce_chunksize) : reduce_work; + + /* Perform reduction here */ + libxsmm_barrier_wait(handle->barrier, ltid); + + for ( ij = reduce_thr_begin; ij < reduce_thr_end; ij++ ) { + element_filter_type *weight_ptr_glb = (element_filter_type*) handle->grad_filter->data; +#if 1 + float weight_sum[64]; + int wtcnt = 0; + assert( handle->ofmblock <= 64 ); + + LIBXSMM_PRAGMA_SIMD + for ( wtcnt = 0; wtcnt < fm_blocking; ++wtcnt ) { + weight_sum[wtcnt] = 0.0f; + } + + for ( ii = 0; ii < handle->weight_copies; ii++ ) { + element_filter_type *weight_ptr_src = (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset) + ii * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S + ij * fm_blocking; + LIBXSMM_PRAGMA_SIMD + for ( wtcnt = 0; wtcnt < fm_blocking; ++wtcnt ) { + weight_sum[wtcnt] += weight_ptr_src[wtcnt]; + } + } + + LIBXSMM_PRAGMA_SIMD + for ( wtcnt = 0; wtcnt < fm_blocking; ++wtcnt ) { + weight_ptr_glb[(ij*fm_blocking) + wtcnt] = weight_sum[wtcnt]; + } +#else + __m512 weight_sum = _mm512_setzero_ps(); + for ( ii = 0; ii < handle->weight_copies; ii++ ) { + element_filter_type *weight_ptr_src = (element_filter_type*)handle->scratch7 + ii * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S + ij * 16; + weight_sum = _mm512_add_ps(weight_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(weight_ptr_src)); + } + _mm512_storeu_ps(&weight_ptr_glb[ij*16], weight_sum); +#endif + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..8ef6e8e2c608e9c374a5049c2d533be5cee70923 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c @@ -0,0 +1,723 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +#define TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1) do {\ + __m512i zero_reg = _mm512_setzero_si512();\ + src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\ + tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\ + for (pixel_pair = 0; pixel_pair < n_full_pixel_pairs; pixel_pair++) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\ + pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\ + ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\ + ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\ + _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\ + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\ + }\ + src_out += 2* handle->ofmblock;\ + tr_out += 2*handle->ofmblock;\ + }\ + if (half_pixel_pair == 1) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\ + pixel_1 = _mm512_setzero_si512();\ + ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\ + ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\ + _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\ + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\ + }\ + }\ + for (oi = ((handle->compute_pixels+1)/2)*2; oi < handle->output_pixels; oi+=2) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\ + _mm512_storeu_si512((element_output_type*)tr_out, zero_reg);\ + _mm512_storeu_si512((element_output_type*)tr_out+32, zero_reg);\ + }\ + }\ +} while(0) + +#define TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, oj, H) do {\ + int h, w_pixel_pair, w_full_pixel_pairs = handle->ofwp/2;\ + for (h=0; hblocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\ + tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, 0, h, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);\ + for (w_pixel_pair = 0; w_pixel_pair < w_full_pixel_pairs; w_pixel_pair++) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\ + pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\ + ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\ + ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\ + _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\ + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\ + }\ + src_out += 2* handle->ofmblock;\ + tr_out += 2*handle->ofmblock;\ + }\ + }\ +} while(0) + +int img, my_img_start, my_img_end, ofmb, ifmb, ofm1, ifm1, ifm2, ofm2, oj, oi, ii, ij, kj, ki, j_br, img_br, i, j, img_block_size = 1, my_ofm_start, my_ofm_end, my_ifm_start, my_ifm_end, block_ofm, block_ifm, pix; +/* computing first logical thread */ +const int ltid = tid - start_thread; + +const int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : handle->ifwp; +const int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : handle->ifhp; +const int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2*handle->desc.pad_w : handle->ofwp; +const int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2*handle->desc.pad_h : handle->ofhp; + +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +LIBXSMM_VLA_DECL(5, const element_output_type, output, (const element_output_type*)out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(5, const element_input_type, input, (const element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + +element_filter_type *weight_ptr = (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; + +element_filter_type *filter_dst_ptr = (handle->weight_copies > 1) ? (element_filter_type*)weight_ptr : (element_filter_type*)handle->grad_filter->data; +LIBXSMM_VLA_DECL(7, element_filter_type, weight_dst, (element_filter_type*)filter_dst_ptr, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2); + +/* This intermediate tensor is used when pixels are NOT fully accumulated */ +float *weight_ptr_f32 = (float*) ((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; + +LIBXSMM_VLA_DECL(6, float, weight_private_f32, (float*)weight_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +/* Accumulation scratch is used when pixels are ully accumulated */ +element_filter_type *filter_scratch = (element_filter_type*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->ofmblock * handle->ifmblock * 2; + +LIBXSMM_VLA_DECL(2, float, filter_tmp, (float*)filter_scratch, handle->ofmblock); + +element_input_type *scratch_tr_input = (element_input_type*)((char*)handle->scratch + handle->upd_lp_input_full_scratch_offset); +element_input_type *zero_ptr_in; +element_output_type *zero_ptr_out; +LIBXSMM_VLA_DECL(4, element_input_type, tr_input, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, handle->input_pixels); +LIBXSMM_VLA_DECL(5, element_input_type, tr_input_2, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended); + +element_output_type *scratch_tr_output = (element_input_type*)((char*)handle->scratch + handle->upd_lp_output_full_scratch_offset); +LIBXSMM_VLA_DECL(5, element_output_type, tr_output, (element_output_type*) scratch_tr_output, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); +LIBXSMM_VLA_DECL(6, element_output_type, tr_output_2, (element_output_type*) scratch_tr_output, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2); +#if 0 +element_output_type *out_ptr = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +element_output_type *zero_ptr_out; +#endif + +/* transpose, copy and reduce work-related variables */ +const int reduce_work = (handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S)/16; +const int reduce_chunksize = (reduce_work % handle->desc.threads == 0) ? (reduce_work / handle->desc.threads) : (reduce_work / handle->desc.threads) + 1; +const int reduce_thr_begin = (ltid * reduce_chunksize < reduce_work) ? (ltid * reduce_chunksize) : reduce_work; +const int reduce_thr_end = ((ltid + 1) * reduce_chunksize < reduce_work) ? ((ltid + 1) * reduce_chunksize) : reduce_work; + +const float beta = (handle->use_intermediate_f32_wt_tensor ? 1.f : 0.f); +float *dst_ptr; +gemm_br_function br_gemm_kernel = 0; + +/* These are used for the vnni reformatting of the f32 output */ +__m512i c01 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); +const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + +/* Related to the output transpose */ +int n_full_pixel_pairs = handle->compute_pixels/2, half_pixel_pair = handle->compute_pixels%2, pixel_pair; +element_output_type *tr_out, *src_out; +const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0); +const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); +const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16); +const __m512i idx_lo = _mm512_or_epi32(selector, offsets_lo); +const __m512i idx_hi = _mm512_or_epi32(selector, offsets_hi); +__m512i pixel_0, pixel_1, ofms_lo, ofms_hi; + +/* Batch reduce related variables */ +const element_output_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +libxsmm_blasint LDA = handle->ofmblock; +libxsmm_blasint LDB = handle->input_pixels; +libxsmm_blasint LDC = handle->ofmblock; +int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); +int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + +const int img_work = handle->desc.N; +const int img_chunksize = (img_work % handle->desc.threads == 0) ? (img_work / handle->desc.threads) : (img_work / handle->desc.threads) + 1; +my_img_start = (ltid * img_chunksize < img_work) ? (ltid * img_chunksize) : img_work; +my_img_end = ((ltid + 1) * img_chunksize < img_work) ? ((ltid + 1) * img_chunksize) : img_work; + +libxsmm_barrier_init(handle->barrier, ltid); + +if (handle->upd_linearized_pixels == 1) { + /* First transpose input and output */ + if (handle->use_hybrid_imgofm_parallelization == 1) { + if (handle->upd_pack_input_upfront == 0) { + for (img = my_img_start; img < my_img_end; img++) { + if (handle->upd_padding_copy == 1) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->ifmblock * handle->input_pixels * sizeof(element_input_type)); + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, (ij + handle->desc.pad_h) * IFWP + (ii + handle->desc.pad_w), handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } else { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels), + handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels ); +#if 0 + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * handle->ifwp + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } +#endif + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { +#if 0 + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type)); +#endif + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) { + transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, ij * (handle->ifwp/handle->desc.v), handle->blocksifm, handle->ifmblock, handle->input_pixels), + handle->ifmblock, handle->ifwp/handle->desc.v, 2*handle->ifmblock, handle->input_pixels ); +#if 0 + for (ii = 0; ii < handle->ifwp/handle->desc.v; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * (handle->ifwp/handle->desc.v) + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, ii*handle->desc.v, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } +#endif + } + } + } + } + + if (handle->upd_padding_copy == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->ofmblock * handle->output_pixels * sizeof(element_output_type)); + for (oj = 0; oj < handle->ofhp; oj++) { + for (oi = 0; oi < handle->ofwp; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, (oj*OFWP+oi)/2, ofm2, (oj*OFWP+oi)%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1); + } + } + } + } +#if 0 + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, 0, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->desc.K * handle->output_pixels * sizeof(element_output_type)); + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + for (oi = 0; oi < handle->n_used_pixels; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, oi%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) = + *((element_output_type*)out_ptr + img * handle->blocksofm * handle->ofwp * handle->ofhp * handle->ofmblock + ofm1 * handle->ofwp * handle->ofhp * handle->ofmblock + oi * handle->ofmblock + ofm2); + } + } + } + } +#endif +} else { + if (handle->upd_trans_w_only == 0) { + if (handle->on_the_fly_input_packing == 0) { + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended); + memset(zero_ptr_in, 0, handle->desc.C * handle->ifhp * handle->ifwp_extended * sizeof(element_input_type)); + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { +#if 0 + TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, 0, handle->ofh); +#else + for (oj = 0; oj < handle->ofh; oj++) { +#if 0 + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, 0, 0, 0, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->ofmblock * handle->ofwp_extended * sizeof(element_output_type)); +#endif + for (oi = 0; oi < handle->ofw; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, oi/2, ofm2, oi%2, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + if (handle->ofw % 2 == 1) { + for (oj = 0; oj < handle->ofh; oj++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, handle->ofw/2, ofm2, handle->ofw%2, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2) = (element_output_type)0; + } + } + } +#endif + } + } + } +} + +/* Make sure we initialize intermediate weights to zero */ +if (handle->use_intermediate_f32_wt_tensor == 1 && handle->use_hybrid_imgofm_parallelization == 0) { + memset(weight_ptr_f32, 0, handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S * sizeof(float)); +} + +if (handle->upd_linearized_pixels == 0) { + if (handle->upd_trans_w_only == 1) { + LDA = handle->ofmblock; + LDB = IFHP*handle->ifwp_extended; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + n_blocks = handle->batchreduce_h_pixels; + br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels){ + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + /* Transpose output block */ + TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, oj, handle->batchreduce_h_pixels); + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + /* Transpose input block */ + for (j=0; j < handle->batchreduce_h_pixels; j++) { + transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj+j, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + (element_input_type*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), + handle->ifmblock, handle->ifwp_extended, handle->ifmblock, handle->ifhp*handle->ifwp_extended ); + } + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + + for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) { + A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, 0, j_br, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2); + B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j_br, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended); + } + + br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks); + + /* Convert fully accumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if ((oj + handle->batchreduce_h_pixels >= handle->ofh) && (img == my_img_end - 1)) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), + LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) ); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + } + } + } + } + } + } + } + } else { + int fast_trans = (handle->ofw == 112 && handle->desc.v == 2 && handle->ifmblock == 4 && handle->batchreduce_h_pixels == 1) ? 1 : 0; + const __m512i skipper = LIBXSMM_INTRINSICS_MM512_SET_EPI16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0); + __m512i p0, p1, p2, p3; + __m256i _p0, _p1, _p2, _p3; + __m256i r0 = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256(); + __m256i r1 = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256(); + __m256i r2 = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256(); + __m256i r3 = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256(); + LDA = handle->ofmblock; + LDB = IFHP*handle->ifwp_extended; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + n_blocks = handle->batchreduce_h_pixels; + /* Handle case when ofw is odd number... */ + if (handle->ofw % 2 == 1) { + br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw+1, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + } else { + br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + } + + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels){ + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + + /* Copy the input in such a way that we ignore "w-pixels" based on ki value */ + if (handle->on_the_fly_input_packing == 1) { + if ((fast_trans == 1) && (handle->upd_padding_copy == 0)) { + for (ii = 0; ii < handle->ofw*2; ii+=32) { + p0 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock)); + p0 = _mm512_permutexvar_epi16(skipper, p0); + _p0 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p0, 0); + p1 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+8+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock)); + p1 = _mm512_permutexvar_epi16(skipper, p1); + _p1 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p1, 0); + p2 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+16+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock)); + p2 = _mm512_permutexvar_epi16(skipper, p2); + _p2 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p2, 0); + p3 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+24+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock)); + p3 = _mm512_permutexvar_epi16(skipper, p3); + _p3 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p3, 0); + + r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p0, 0), 0); + r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p1, 0), 1); + r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p2, 0), 2); + r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p3, 0), 3); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r0); + + r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p0, 1), 0); + r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p1, 1), 1); + r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p2, 1), 2); + r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p3, 1), 3); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 1, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r1); + + r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p0, 2), 0); + r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p1, 2), 1); + r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p2, 2), 2); + r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p3, 2), 3); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 2, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r2); + + r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p0, 3), 0); + r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p1, 3), 1); + r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p2, 3), 2); + r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p3, 3), 3); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 3, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r3); + + } + } else { + if (handle->upd_padding_copy == 1) { + for (ij = 0; ij < handle->batchreduce_h_pixels; ij++) { + for (ii = 0; ii < handle->ofw; ii++) { + int j_pixel = (oj+ij)*handle->desc.u+kj; + int i_pixel = ii*handle->desc.v+ki; + if ( (j_pixel >= handle->desc.pad_h) && (i_pixel >= handle->desc.pad_w) && (j_pixel < handle->ifhp+handle->desc.pad_h) && (i_pixel < handle->ifwp+handle->desc.pad_w) ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, (oj+ij)*handle->desc.u+kj-handle->desc.pad_h, ii*handle->desc.v+ki-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } else { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = (element_input_type)0; + } + } + } + } + } else { + for (ij = 0; ij < handle->batchreduce_h_pixels; ij++) { + for (ii = 0; ii < handle->ofw; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, (oj+ij)*handle->desc.u+kj, ii*handle->desc.v+ki, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } + + for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) { + A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj+j_br, 0, 0, 0, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2); + B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j_br, 0, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended); + } + + br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks); + + /* Convert fully accumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if ((oj + handle->batchreduce_h_pixels >= handle->ofh) && (img == my_img_end - 1)) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), + LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock))); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + } + } + } + } + } + } + } + } +} else { + LDA = handle->ofmblock; + LDB = handle->input_pixels; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N'); + + if (handle->use_hybrid_imgofm_parallelization == 1) { + /* Here we are using batch-reduce kernel and hybrid minibatch/FM parallelization */ + /* FIXME: Hardcoed logic for N=27 */ + int group_size = (handle->desc.threads == 27 && handle->desc.N == 27 && handle->ofw == 14 && handle->desc.R == 1 && handle->desc.u == 1 && ltid >= 24) ? 3 : LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies); + int tile_id = ltid / LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies); + int tiles = handle->weight_copies; + int img_per_tile = LIBXSMM_UPDIV(handle->desc.N, tiles); + int my_in_tile_id = ltid % group_size; + int ifms_per_thread = LIBXSMM_UPDIV(handle->blocksifm, group_size); + int ofms_per_thread = LIBXSMM_UPDIV(handle->blocksofm, group_size); + int my_R_start = 0; + int my_R_end = handle->desc.R; + element_filter_type *weight_ptr_group = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; + LIBXSMM_VLA_DECL(7, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2); + /* This intermediate tensor is used when pixels are NOT fully accumulated */ + float *weight_tile_ptr_f32 = (float*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; + LIBXSMM_VLA_DECL(6, float, weight_private_tile_f32, (float*)weight_tile_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + + my_img_start = LIBXSMM_MIN(tile_id * img_per_tile, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * img_per_tile, handle->desc.N); + my_ifm_start = LIBXSMM_MIN(my_in_tile_id * ifms_per_thread, handle->blocksifm ); + my_ifm_end = LIBXSMM_MIN((my_in_tile_id+1) * ifms_per_thread, handle->blocksifm ); + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + /* FIXME: Hardcoed logic for N=27 */ + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.C == 256 && handle->desc.K == 1024 && handle->ofh == 14 && handle->desc.u == 1) { + my_ofm_start = LIBXSMM_MIN(my_in_tile_id * ofms_per_thread, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((my_in_tile_id+1) * ofms_per_thread, handle->blocksofm); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + } + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 14) { + int r_per_tile = LIBXSMM_UPDIV(handle->desc.R, group_size); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + my_R_start = LIBXSMM_MIN(my_in_tile_id * r_per_tile, handle->desc.R); + my_R_end = LIBXSMM_MIN((my_in_tile_id+1) * r_per_tile, handle->desc.R); + } + block_ofm = my_ofm_end-my_ofm_start+1; + block_ifm = my_ifm_end-my_ifm_start+1; + img_block_size = my_img_end - my_img_start; + + br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + n_blocks = img_block_size; + + /* Make sure we initialize intermediate weights to zero */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++ ) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + memset((float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), 0, handle->ofmblock * handle->ifmblock * handle->desc.S * sizeof(float)); + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + for (img = my_img_start; img < my_img_end; img += img_block_size) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking){ + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + + for (img_br = 0; img_br < img_block_size; img_br++) { + A_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(5, tr_output, img + img_br, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + B_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(4, tr_input, img + img_br, ifm1, 0, pix + kj * IFWP + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels); + } + + br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks); + + /* Convert fully caccumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if ((pix + handle->pixel_blocking >= handle->n_used_pixels) && (img == my_img_end - img_block_size)) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), + LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) ); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_private_group, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + } + } + } + } + } + } + } + + } else { + gemm_function gemm_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking){ + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + /* Transpose output block */ + if (pix == 0 && ifmb == 0) { + if (handle->upd_padding_copy == 1) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->ofmblock * handle->output_pixels * sizeof(element_output_type)); + for (oj = 0; oj < handle->ofhp; oj++) { + for (oi = 0; oi < handle->ofwp; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, (oj*OFWP+oi)/2, ofm2, (oj*OFWP+oi)%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } else { + TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1); + } + } + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + /* Transpose input block */ + if (pix == 0 && ofmb == 0 && ofm1 == 0) { + if (handle->upd_padding_copy == 1) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->ifmblock * handle->input_pixels * sizeof(element_input_type)); + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, (ij + handle->desc.pad_h) * IFWP + (ii + handle->desc.pad_w), handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } else { + if (handle->upd_pack_input_upfront == 0) { + transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels), + handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels ); + } else { + for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) { + transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, ij * (handle->ifwp/handle->desc.v), handle->blocksifm, handle->ifmblock, handle->input_pixels), + handle->ifmblock, handle->ifwp/handle->desc.v, 2*handle->ifmblock, handle->input_pixels ); + } + } + } + } + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2), + &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, pix + kj * IFWP + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels), + dst_ptr); + + /* Convert fully accumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if ((pix + handle->pixel_blocking >= handle->n_used_pixels) && (img == my_img_end - 1)) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), + LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) ); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + } + } + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +if (handle->weight_copies > 1) { + int active_copies = handle->weight_copies; + const int filter_size = handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K; + LIBXSMM_VLA_DECL(2, element_filter_type, weight_copies_buffer, (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset), filter_size); + element_filter_type *weight_global_ptr = (element_filter_type*) handle->grad_filter->data; + + /* In this case calculate how many weight copies have been indeed computed */ + if (handle->desc.N != handle->desc.threads) { + active_copies = 1; + while (active_copies * img_chunksize < handle->desc.N) { + active_copies++; + } + } + + for ( j = reduce_thr_begin; j < reduce_thr_end; j++) { + __m512 weight_sum = _mm512_setzero_ps(); + for ( i = 0; i < active_copies; i++ ) { + weight_sum = _mm512_add_ps(weight_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, weight_copies_buffer, i, j*16, filter_size)))); + } + _mm256_storeu_si256((__m256i*)(((libxsmm_bfloat16*) weight_global_ptr) + j*16), LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(weight_sum)); + } + libxsmm_barrier_wait(handle->barrier, ltid); +} + +#undef TRANS_OUTPUT_W_TO_VNNI_FORMAT +#undef TRANS_OUTPUT_TO_VNNI_FORMAT diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..eb7f7d973197fe267f2e50d789271e8146f85c84 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c @@ -0,0 +1,783 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#define TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1) do {\ + __m512i zero_reg = _mm512_setzero_si512();\ + src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\ + tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\ + for (pixel_pair = 0; pixel_pair < n_full_pixel_pairs; pixel_pair++) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\ + pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\ + ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\ + ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\ + _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\ + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\ + }\ + src_out += 2* handle->ofmblock;\ + tr_out += 2*handle->ofmblock;\ + }\ + if (half_pixel_pair == 1) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\ + pixel_1 = _mm512_setzero_si512();\ + ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\ + ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\ + _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\ + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\ + }\ + tr_out += 2*handle->ofmblock;\ + } \ + for (oi = (n_full_pixel_pairs+half_pixel_pair)*2; oi < handle->output_pixels; oi+=2) {\ + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\ + _mm512_storeu_si512((element_output_type*)tr_out+ofm2*2, zero_reg);\ + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, zero_reg);\ + } \ + tr_out += 2*handle->ofmblock;\ + }\ +}while(0) + +#define TRANS_INPUT(img, ifm1) do {\ + transpose_input_pixels_bf16((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),(element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels), handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels);\ + if (handle->input_pixels - handle->ifhp*handle->ifwp > 0) {\ + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {\ + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, handle->ifhp * handle->ifwp, handle->blocksifm, handle->ifmblock, handle->input_pixels);\ + memset(zero_ptr_in, 0, (handle->input_pixels - handle->ifhp * handle->ifwp)*sizeof(element_input_type));\ + }\ + }\ +} while(0) + +int img, my_img_start, my_img_end, ofmb, ifmb, ofm1, ifm1, ifm2, ofm2, oj, oi, ii, ij, kj, ki, /*j_br, img_br,*/ i, j, img_block_size = 1, my_ofm_start, my_ofm_end, my_ifm_start, my_ifm_end, block_ofm, block_ifm, pix; +/* computing first logical thread */ +const int ltid = tid - start_thread; + +const int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : handle->ifwp; +const int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : handle->ifhp; +const int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2*handle->desc.pad_w : handle->ofwp; +const int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2*handle->desc.pad_h : handle->ofhp; + +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +LIBXSMM_VLA_DECL(5, const element_output_type, output, (const element_output_type*)out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); +LIBXSMM_VLA_DECL(5, const element_input_type, input, (const element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + +element_filter_type *weight_ptr = (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; +element_filter_type *filter_dst_ptr = (handle->weight_copies > 1) ? (element_filter_type*)weight_ptr : (element_filter_type*)handle->grad_filter->data; +LIBXSMM_VLA_DECL(7, element_filter_type, weight_dst, (element_filter_type*)filter_dst_ptr, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2); + +/* This intermediate tensor is used when pixels are NOT fully accumulated */ +float *weight_ptr_f32 = (float*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; +LIBXSMM_VLA_DECL(6, float, weight_private_f32, (float*)weight_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +/* Accumulation scratch is used when pixels are ully accumulated */ +element_filter_type *filter_scratch = (element_filter_type*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->ofmblock * handle->ifmblock * 2; +LIBXSMM_VLA_DECL(2, float, filter_tmp, (float*)filter_scratch, handle->ofmblock); + +element_input_type *scratch_tr_input = (element_input_type*)((char*)handle->scratch + handle->upd_lp_input_full_scratch_offset); +element_input_type *zero_ptr_in; +LIBXSMM_VLA_DECL(4, element_input_type, tr_input, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, handle->input_pixels); +LIBXSMM_VLA_DECL(5, element_input_type, tr_input_2, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended); +LIBXSMM_VLA_DECL(3, element_input_type, tr_input_3, (element_input_type*) scratch_tr_input, handle->ifmblock, handle->input_pixels); + +element_output_type *scratch_tr_output = (element_input_type*)((char*)handle->scratch + handle->upd_lp_output_full_scratch_offset); +LIBXSMM_VLA_DECL(5, element_output_type, tr_output, (element_output_type*) scratch_tr_output, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); +LIBXSMM_VLA_DECL(6, element_output_type, tr_output_2, (element_output_type*) scratch_tr_output, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2); +LIBXSMM_VLA_DECL(4, element_output_type, tr_output_3, (element_output_type*) scratch_tr_output, handle->output_pixels/2, handle->ofmblock, 2); + +element_output_type *out_ptr = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock; +element_output_type *zero_ptr_out; + +/* transpose, copy and reduce work-related variables */ +const int reduce_work = (handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S)/16 ; +const int reduce_chunksize = (reduce_work % handle->desc.threads == 0) ? (reduce_work / handle->desc.threads) : (reduce_work / handle->desc.threads) + 1; +const int reduce_thr_begin = (ltid * reduce_chunksize < reduce_work) ? (ltid * reduce_chunksize) : reduce_work; +const int reduce_thr_end = ((ltid + 1) * reduce_chunksize < reduce_work) ? ((ltid + 1) * reduce_chunksize) : reduce_work; + +#if 0 +const float beta = (handle->use_intermediate_f32_wt_tensor) ? 1.0 : 0.0; +#endif +float *dst_ptr; +#if 0 +gemm_br_function br_gemm_kernel = 0; +#endif + +/* These are used for the vnni reformatting of the f32 output */ +__m512i c01; +const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + +/* Related to the output transpose */ +int n_full_pixel_pairs = handle->compute_pixels/2, half_pixel_pair = handle->compute_pixels%2, pixel_pair; +element_output_type *tr_out, *src_out; +const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0); +const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); +const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16); +const __m512i idx_lo = _mm512_or_epi32(selector, offsets_lo); +const __m512i idx_hi = _mm512_or_epi32(selector, offsets_hi); +__m512i pixel_0, pixel_1, ofms_lo, ofms_hi; + +/* Batch reduce related variables */ +#if 0 +const element_output_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +#endif +unsigned long long n_blocks; + +#if 0 +int LDA = handle->ofmblock; +int LDB = handle->input_pixels; +int LDC = handle->ofmblock; +int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); +int l_flags = (LIBXSMM_GEMM_FLAGS('N', 'N')) | LIBXSMM_GEMM_FLAG_EXCLUDE_TILECONFIG; +int l_tc_flags = LIBXSMM_GEMM_FLAG_ONLY_TILECONFIG; +gemm_function tile_config_kernel = 0; +#endif + +const int img_work = handle->desc.N; +const int img_chunksize = (img_work % handle->desc.threads == 0) ? (img_work / handle->desc.threads) : (img_work / handle->desc.threads) + 1; + +/* select kernel */ +if (handle->upd_linearized_pixels == 0) { + br_gemm_kernel = handle->upd_compute_kernel_brgemm_no_linearized_pixels; + gemm_kernel = handle->upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par; /* @TODO: ci check */ +} else { + if (handle->use_hybrid_imgofm_parallelization == 0) { + gemm_kernel = handle->upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par; + br_gemm_kernel = handle->upd_compute_kernel_brgemm_no_linearized_pixels; /* @TODO: ci check */ + } else { +#if 0 /* if/else branches with same outcome */ + if (handle->pack_to_cnhw == 1) +#endif + { + gemm_kernel = handle->upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw; + br_gemm_kernel = handle->upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw; /* @TODO: ci check */ + } +#if 0 /* if/else branches with same outcome */ + else { + gemm_kernel = handle->upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw; /* @TODO: ci check */ + br_gemm_kernel = handle->upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw; + } +#endif + } +} + +my_img_start = (ltid * img_chunksize < img_work) ? (ltid * img_chunksize) : img_work; +my_img_end = ((ltid + 1) * img_chunksize < img_work) ? ((ltid + 1) * img_chunksize) : img_work; + +libxsmm_barrier_init(handle->barrier, ltid); + +if (handle->upd_linearized_pixels == 1) { + /* First transpose input and output */ + if (handle->pack_to_cnhw == 0) { + if (handle->fuse_upd_transposes == 0) { + if (handle->upd_pack_input_upfront == 0) { + if (handle->upd_padding_copy == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->ifmblock * handle->input_pixels * sizeof(element_input_type)); + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, (ij + handle->desc.pad_h) * IFWP + (ii + handle->desc.pad_w), handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } else { + if (handle->ifmblock % 32 == 0) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + TRANS_INPUT(img, ifm1); + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type)); + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * handle->ifwp + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type)); + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) { + for (ii = 0; ii < handle->ifwp/handle->desc.v; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * (handle->ifwp/handle->desc.v) + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, ii*handle->desc.v, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } + + /* Reformat output */ + if (handle->upd_padding_copy == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->ofmblock * handle->output_pixels * sizeof(element_output_type)); + for (oj = 0; oj < handle->ofhp; oj++) { + for (oi = 0; oi < handle->ofwp; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, (oj*OFWP+oi)/2, ofm2, (oj*OFWP+oi)%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } + } + } else { + if (handle->ofmblock % 32 == 0) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1); + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, 0, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->desc.K * handle->output_pixels * sizeof(element_output_type)); + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + for (oi = 0; oi < handle->compute_pixels; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, oi%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) = + *((element_output_type*)out_ptr + img * handle->blocksofm * handle->ofwp * handle->ofhp * handle->ofmblock + ofm1 * handle->ofwp * handle->ofhp * handle->ofmblock + oi * handle->ofmblock + ofm2); + } + } + } + } + } + } + } + } else { + int img_tile_id, img_in_tile, init_offset, /*pix_id,*/ images_in_tile = handle->desc.N/handle->weight_copies; + /* Zero out the input padding pixels */ + for (img = my_img_start; img < my_img_end; img++) { + img_tile_id = img/images_in_tile; + img_in_tile = img%images_in_tile; + if (img_in_tile == images_in_tile-1) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(3, tr_input_3, ifm1, ifm2, img_tile_id * handle->pixel_blocking + images_in_tile * (handle->ifhp/handle->desc.u) * (handle->ifwp/handle->desc.v), handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->remainder_pixels * sizeof(element_input_type)); + } + } + } + } + + if ((handle->ifmblock % 32 == 0) && (handle->desc.u == 1) && (handle->desc.v == 1)) { + for (img = my_img_start; img < my_img_end; img++) { + img_tile_id = img/images_in_tile; + img_in_tile = img%images_in_tile; + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + transpose_input_pixels_bf16((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock), + (element_input_type*)&LIBXSMM_VLA_ACCESS(3, tr_input_3, ifm1, 0, img_tile_id * handle->pixel_blocking + img_in_tile * handle->ifhp * handle->ifwp, handle->ifmblock, handle->input_pixels) , + handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels); + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + img_tile_id = img/images_in_tile; + img_in_tile = img%images_in_tile; + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) { + for (ii = 0; ii < handle->ifwp/handle->desc.v; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(3, tr_input_3, ifm1, ifm2, img_tile_id * handle->pixel_blocking + img_in_tile * (handle->ifhp/handle->desc.u) * (handle->ifwp/handle->desc.v) + ij * (handle->ifwp/handle->desc.v) + ii, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, ii*handle->desc.v, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } + + /* Zero out the output padding pixels */ + for (img = my_img_start; img < my_img_end; img++) { + img_tile_id = img/images_in_tile; + img_in_tile = img%images_in_tile; + if (img_in_tile == images_in_tile-1) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + init_offset = img_tile_id * handle->pixel_blocking + images_in_tile * handle->ofw * handle->ofh; + tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, init_offset/2, 0, init_offset%2, handle->output_pixels/2, handle->ofmblock, 2); + memset(tr_out, 0, handle->remainder_pixels * handle->ofmblock * sizeof(element_input_type)); +#if 0 + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + for (oi = 0; oi < handle->remainder_pixels; oi++ ) { + init_offset = img_tile_id * handle->pixel_blocking + images_in_tile * handle->ofw * handle->ofh; + pix_id = init_offset + oi; + LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, pix_id/2, ofm2, pix_id%2, handle->output_pixels/2, handle->ofmblock, 2) = (element_output_type)0; + } + } +#endif + } + } + } + + if (handle->ofmblock % 32 == 0) { + int _trans_pixels = handle->ofw*handle->ofh, _n_full_pixel_pairs, _half_pixel_pair, init_pixel_pos; + for (img = my_img_start; img < my_img_end; img++) { + int pix_id; + img_tile_id = img/images_in_tile; + img_in_tile = img%images_in_tile; + pix_id = img_tile_id * handle->pixel_blocking + img_in_tile * handle->ofh * handle->ofw; + /* The first-odd pixel is done with scalar code... */ + if (pix_id % 2 == 1) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, pix_id/2, ofm2, 1, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, 0, 0, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + pix_id += 1; + _trans_pixels--; + init_pixel_pos = 1; + } else { + init_pixel_pos = 0; + } + _n_full_pixel_pairs = _trans_pixels/2; + _half_pixel_pair = _trans_pixels%2; + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, 0, init_pixel_pos, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, pix_id/2, 0, 0, handle->output_pixels/2, handle->ofmblock, 2); + for (pixel_pair = 0; pixel_pair < _n_full_pixel_pairs; pixel_pair++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) { + pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2); + pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2)); + ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1); + ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1); + _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo); + _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi); + } + src_out += 2* handle->ofmblock; + tr_out += 2*handle->ofmblock; + } + } + /* The last-odd pixel is done with scalar code... */ + if (_half_pixel_pair == 1) { + pix_id = pix_id + _n_full_pixel_pairs*2; + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, pix_id/2, ofm2, pix_id%2, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, handle->ofh-1, handle->ofw-1, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + img_tile_id = img/images_in_tile; + img_in_tile = img%images_in_tile; + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + int pix_id = img_tile_id * handle->pixel_blocking + img_in_tile * handle->ofh * handle->ofw + oj * handle->ofw + oi; + LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, pix_id/2, ofm2, pix_id%2, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } + } + } + } +} else { + if (handle->on_the_fly_input_packing == 0) { + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, 0, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended); + memset(zero_ptr_in, 0, handle->desc.C * handle->ifhp * handle->ifwp_extended * sizeof(element_input_type)); + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img++) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, 0, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended); + memset(zero_ptr_in, 0, handle->desc.C * IFHP * handle->ifwp_extended * sizeof(element_input_type)); + } + } + for (img = my_img_start; img < my_img_end; img++) { + for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, 0, 0, 0, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->ofmblock * (handle->ofw+handle->remainder_pixels) * sizeof(element_output_type)); + for (oi = 0; oi < handle->ofw; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, oi/2, ofm2, oi%2, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } + } +} + +/* Make sure we initialize intermediate weights to zero */ +if (handle->use_intermediate_f32_wt_tensor == 1 && handle->use_hybrid_imgofm_parallelization == 0) { + memset(weight_ptr_f32, 0, handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S * sizeof(float)); +} + +tile_config_kernel(NULL, NULL, NULL); + +if (handle->upd_linearized_pixels == 0) { +#if 0 + LDA = handle->ofmblock; + LDB = handle->ifhp*handle->ifwp_extended; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); + br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw+handle->remainder_pixels, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + tile_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->ofw+handle->remainder_pixels, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); +#endif + n_blocks = handle->batchreduce_h_pixels; + + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels){ + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + + /* Copy the input in such a way that we ignore "w-pixels" based on ki value */ + if (handle->on_the_fly_input_packing == 1) { + if (handle->upd_padding_copy == 1) { + for (ij = kj; ij < IFHP; ij+=handle->desc.u) { + for (ii = 0; ii < handle->ofw; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + if ( (ij >= handle->desc.pad_h) && (ii*handle->desc.v+ki >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii*handle->desc.v+ki < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij-handle->desc.pad_h, ii*handle->desc.v+ki-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } else { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = (element_input_type)0; + } + } + } + } + } else { + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ofw; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii*handle->desc.v+ki, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } + } + +#if 0 + for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) { + A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj+j_br, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2); + B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, 0, (oj+j_br)*handle->desc.u + kj, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended); + } + br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks); +#endif + br_gemm_kernel( &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, 0, 0, 0, handle->blocksofm, OFHP, handle->ofwp_extended/2, handle->ofmblock, 2), + &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, 0, oj*handle->desc.u + kj, 0, handle->blocksifm, handle->ifmblock, IFHP, handle->ifwp_extended), dst_ptr, &n_blocks); + + /* Convert fully caccumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if (oj + handle->batchreduce_h_pixels >= handle->ofh) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index,(__m512i)c01)); + } + } + } + + } + } + } + } + } + } + } + } +} else { +#if 0 + LDA = handle->ofmblock; + LDB = handle->input_pixels; + LDC = handle->ofmblock; + prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE); +#endif + if (handle->use_hybrid_imgofm_parallelization == 1) { + /* Here we are using batch-reduce kernel and hybrid minibatch/FM parallelization */ + /* FIXME: Hardcoed logic for N=27 */ + int group_size = (handle->desc.threads == 27 && handle->desc.N == 27 && handle->ofw == 14 && handle->desc.R == 1 && handle->desc.u == 1 && ltid >= 24) ? 3 : ((handle->desc.threads+handle->weight_copies-1)/handle->weight_copies); + int tile_id = ltid/( (handle->desc.threads+handle->weight_copies-1)/handle->weight_copies ); + int tiles = handle->weight_copies; + int img_per_tile = (handle->desc.N+tiles-1)/tiles; + int my_in_tile_id = ltid % group_size; + int ifms_per_thread = (handle->blocksifm+group_size-1)/group_size; + int ofms_per_thread = (handle->blocksofm+group_size-1)/group_size; + int my_R_start = 0; + int my_R_end = handle->desc.R; + element_filter_type *weight_ptr_group = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; + LIBXSMM_VLA_DECL(7, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2); + /* This intermediate tensor is used when pixels are NOT fully accumulated */ + float *weight_tile_ptr_f32 = (float*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; + LIBXSMM_VLA_DECL(6, float, weight_private_tile_f32, (float*)weight_tile_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + + my_img_start = LIBXSMM_MIN( tile_id * img_per_tile, handle->desc.N); + my_img_end = LIBXSMM_MIN( (tile_id+1) * img_per_tile, handle->desc.N); + my_ifm_start = LIBXSMM_MIN( my_in_tile_id * ifms_per_thread, handle->blocksifm ); + my_ifm_end = LIBXSMM_MIN( (my_in_tile_id+1) * ifms_per_thread, handle->blocksifm ); + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + /* FIXME: Hardcoed logic for N=27 */ + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.C == 256 && handle->desc.K == 1024 && handle->ofh == 14 && handle->desc.u == 1) { + my_ofm_start = LIBXSMM_MIN( my_in_tile_id * ofms_per_thread, handle->blocksofm ); + my_ofm_end = LIBXSMM_MIN( (my_in_tile_id+1) * ofms_per_thread, handle->blocksofm ); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + } + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 14) { + int r_per_tile = (handle->desc.R+group_size-1)/group_size; + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + my_R_start = LIBXSMM_MIN( my_in_tile_id * r_per_tile, handle->desc.R ); + my_R_end = LIBXSMM_MIN( (my_in_tile_id+1) * r_per_tile, handle->desc.R ); + } + if (handle->pack_to_cnhw == 1) { + my_ofm_start = LIBXSMM_MIN( my_in_tile_id * ofms_per_thread, handle->blocksofm ); + my_ofm_end = LIBXSMM_MIN( (my_in_tile_id+1) * ofms_per_thread, handle->blocksofm ); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + } + + block_ofm = my_ofm_end-my_ofm_start+1; + block_ifm = my_ifm_end-my_ifm_start+1; + img_block_size = my_img_end - my_img_start; + + /* Make sure we initialize intermediate weights to zero */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++ ) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + memset((float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), 0, handle->ofmblock * handle->ifmblock * handle->desc.S * sizeof(float)); + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if (handle->pack_to_cnhw == 0) { +#if 0 + br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + tile_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); +#endif + n_blocks = img_block_size; + + for (img = my_img_start; img < my_img_end; img += img_block_size) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking){ + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + +#if 0 + for (img_br = 0; img_br < img_block_size; img_br++) { + A_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(5, tr_output, img + img_br, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + B_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(4, tr_input, img + img_br, ifm1, 0, pix + kj * handle->ifwp + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels); + } + br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks); +#endif + + br_gemm_kernel( &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2), + &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, pix + kj * IFWP + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels), + dst_ptr, &n_blocks); + + /* Convert fully caccumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if (pix + handle->pixel_blocking >= handle->n_used_pixels) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(7, weight_private_group, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, (__m512i)c01)); + } + } + } + } + } + } + } + } + } + } + } + } else { +#if 0 + gemm_function gemm_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + tile_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL); +#endif + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + gemm_kernel( &LIBXSMM_VLA_ACCESS(4, tr_output_3, ofm1, tile_id * handle->pixel_blocking/2, 0, 0, handle->output_pixels/2, handle->ofmblock, 2), + &LIBXSMM_VLA_ACCESS(3, tr_input_3, ifm1, 0, tile_id * handle->pixel_blocking, handle->ifmblock, handle->input_pixels), + dst_ptr); + /* Convert fully caccumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(7, weight_private_group, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, (__m512i)c01)); + } + } + } + } + } + } + } + } + } + } + + } else { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking){ + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + if ((handle->fuse_upd_transposes == 1) && (pix == 0) && (ifmb == 0)) { + /* (img,ofm1) transpose of output */ + if (handle->upd_padding_copy == 1) { + zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2); + memset(zero_ptr_out, 0, handle->ofmblock * handle->output_pixels * sizeof(element_output_type)); + for (oj = 0; oj < handle->ofhp; oj++) { + for (oi = 0; oi < handle->ofwp; oi++) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { + LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, (oj*OFWP+oi)/2, ofm2, (oj*OFWP+oi)%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) = + LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock); + } + } + } + } else { + TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1); + } + } + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + if ((handle->fuse_upd_transposes == 1) && (pix == 0) && (ofm1 == 0)) { + /* (img,ifm1) transpose of input */ + if (handle->upd_padding_copy == 1) { + zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels); + memset(zero_ptr_in, 0, handle->ifmblock * handle->input_pixels * sizeof(element_input_type)); + for (ij = 0; ij < handle->ifhp; ij++) { + for (ii = 0; ii < handle->ifwp; ii++) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, (ij + handle->desc.pad_h) * IFWP + (ii + handle->desc.pad_w), handle->blocksifm, handle->ifmblock, handle->input_pixels) = + LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock); + } + } + } + } else { + TRANS_INPUT(img, ifm1); + } + } + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */ + if (handle->use_intermediate_f32_wt_tensor == 1) { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); + } else { + dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock); + } + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2), + &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, pix + kj * IFWP + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels), + dst_ptr); + /* Convert fully caccumulated buffer to bf16 weight buffer in case of full accumulation has happened */ + if (pix + handle->pixel_blocking >= handle->n_used_pixels) { + LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock); + for (ij = 0; ij < handle->ifmblock; ij+=2) { + for (ii = 0; ii < handle->ofmblock; ii+=16) { + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, (__m512i)c01)); + } + } + } + + } + } + } + } + } + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +if (handle->weight_copies > 1) { + const int filter_size = handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K; + LIBXSMM_VLA_DECL(2, element_filter_type, weight_copies_buffer, (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset), filter_size); + element_filter_type *weight_global_ptr = (element_filter_type*) handle->grad_filter->data; + for ( j = reduce_thr_begin; j < reduce_thr_end; j++) { + __m512 weight_sum = _mm512_setzero_ps(); + for ( i = 0; i < handle->weight_copies; i++ ) { + weight_sum = _mm512_add_ps(weight_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(2, weight_copies_buffer, i, j*16, filter_size))); + } + _mm512_streamstorecvt_fp32_bf16( ((libxsmm_bfloat16*) weight_global_ptr) + j*16, weight_sum); + } + libxsmm_barrier_wait(handle->barrier, ltid); +} +handle->tilerelease_kernel(NULL, NULL, NULL); + +#undef TRANS_OUTPUT_TO_VNNI_FORMAT +#undef TRANS_INPUT diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..fcc2f533e91e6b60e34373a99f170f163b810315 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c @@ -0,0 +1,675 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +int img, my_img_start, my_img_end, ofmb, ifmb, ojb, ofm1, ifm1, ifm2 = 0, ofm2 = 0, oj, oi, ii, ij, kj, ki, ind, j_br, img_br, img_block_size = 1, my_ofm_start, my_ofm_end, my_ifm_start, my_ifm_end, block_ofm, block_ifm; +/* computing first logical thread */ +const int ltid = tid - start_thread; +libxsmm_blasint LDA = handle->blocksofm * handle->ofmblock; +libxsmm_blasint LDB = (handle->upd_pack_input == 1) ? handle->blocksifm * handle->ifmblock : handle->desc.v * handle->blocksifm * handle->ifmblock; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) +libxsmm_blasint LDC = handle->ofmblock; +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) +libxsmm_blasint LDC = handle->blocksofm * handle->ofmblock; +#endif +int l_flags = LIBXSMM_GEMM_FLAGS('N', 'T'); +element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->blocksofm * handle->ofmblock; +LIBXSMM_VLA_DECL(5, const element_output_type, output, (const element_output_type*)out, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); +const int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : handle->ifwp; +const int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : handle->ifhp; +element_input_type *input_ptr_to_use = (handle->upd_padding_copy == 1) ? (element_input_type*) ((char*)handle->scratch + handle->upd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data; +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type*) input_ptr_to_use, IFHP, IFWP, handle->blocksifm, handle->ifmblock); +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) +LIBXSMM_VLA_DECL(6, element_filter_type, weight_global, (element_filter_type*)handle->grad_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) +LIBXSMM_VLA_DECL(6, element_filter_type, weight_global, (element_filter_type*)handle->grad_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif +element_filter_type *weight_ptr = (handle->weight_copies == 1) ? (element_filter_type*)handle->grad_filter->data : (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) +LIBXSMM_VLA_DECL(6, element_filter_type, weight_private, (element_filter_type*)weight_ptr, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) +LIBXSMM_VLA_DECL(6, element_filter_type, weight_private, (element_filter_type*)weight_ptr, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif +int prefetch_mode = (handle->desc.u == 2 || (handle->desc.R == 3 && handle->ofw == 7) ) ? libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE) : libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BL1); + +/* Batch reduce related variables */ +const element_output_type *A_ptrs[1024]; +const element_input_type *B_ptrs[1024]; +unsigned long long n_blocks; + +int brgemm_pf_oob = 0; +const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB"); +if ( 0 == env_brgemm_pf_oob ) { +} else { + brgemm_pf_oob = atoi(env_brgemm_pf_oob); +} +if (brgemm_pf_oob > 0) { + prefetch_mode = prefetch_mode | libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB); +} + +libxsmm_barrier_init(handle->barrier, ltid); + +/* physical pad input */ +if (handle->upd_padding_copy == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads); + my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N); + my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + + for (img = my_img_start; img < my_img_end; img++) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + /* copy the inner part */ + for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) { + for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) { + if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, ifm2, IFHP, IFWP, handle->blocksifm, handle->ifmblock) = + LIBXSMM_VLA_ACCESS(5, input_src, img, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + } + } else { + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, ifm2, IFHP, IFWP, handle->blocksifm, handle->ifmblock) = (element_input_type)0; + } + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if (handle->upd_use_batchreduce == 0 && handle->upd_linearized_tasklist == 0) { + /* Parallelize over minibatch */ + const int img_work = handle->desc.N; + const int img_chunksize = (img_work % handle->desc.threads == 0) ? (img_work / handle->desc.threads) : (img_work / handle->desc.threads) + 1; + const float beta = ((img_chunksize == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb * handle->upd_ofh_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + my_img_start = (ltid * img_chunksize < img_work) ? (ltid * img_chunksize) : img_work; + my_img_end = ((ltid + 1) * img_chunksize < img_work) ? ((ltid + 1) * img_chunksize) : img_work; + + if (!((img_chunksize == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw))) { + memset(weight_ptr, 0, handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S * sizeof(element_filter_type)); + } + + if (handle->upd_loop_order == 0) { + for (img = my_img_start; img < my_img_end; img++) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_private, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) ); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_private, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock) ); +#endif + } + } + } + } + } + } + } + } + } + } + } + if (handle->upd_loop_order == 1) { + for (img = my_img_start; img < my_img_end; img++) { + for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) { + for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = 0; kj < handle->desc.R; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_private, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) ); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_private, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock) ); +#endif + } + } + } + } + } + } + } + } + } + } + } +} else { + if (handle->upd_linearized_tasklist == 1) { + /* Amount of work when using linearized view of tasks */ + const int work = handle->desc.R * handle->desc.S * handle->blocksofm * handle->blocksifm; + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : (work / handle->desc.threads) + 1; + const int work_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int work_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int work_item; + int Cb = handle->blocksifm; +#if 0 + int Kb = handle->blocksofm; +#endif + int R = handle->desc.R; + int S = handle->desc.S; + + if (handle->upd_avoid_rim_fmas == 0) { + const int IFH = (handle->upd_pack_input == 1) ? handle->ifhp/handle->desc.u : IFHP; + const int IFW = (handle->upd_pack_input == 1) ? handle->ifwp/handle->desc.v : IFWP; + element_input_type *input_ptr_base = (handle->upd_pack_input == 1) ? (element_input_type*)((char*)handle->scratch + handle->upd_packing_padding_scratch_offset) : (element_input_type*)input_ptr_to_use; + LIBXSMM_VLA_DECL(5, element_input_type, input_use, (element_input_type*)input_ptr_base, IFH, IFW, handle->blocksifm, handle->ifmblock); + const float beta = ((handle->desc.N == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb * handle->upd_ofh_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + /* If requested, pack input to avoid strided accesses */ + if (handle->upd_pack_input == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + const int img_chunk = (handle->desc.N % handle->desc.threads == 0) ? handle->desc.N/handle->desc.threads : (handle->desc.N/handle->desc.threads) + 1; + const int img_copy_start = LIBXSMM_MIN(ltid*img_chunk, handle->desc.N); + const int img_copy_end = LIBXSMM_MIN((ltid+1)*img_chunk, handle->desc.N); + + for (img = img_copy_start; img < img_copy_end; img++) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij = oj * handle->desc.u; + ii = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input_use, img, oj, oi, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ij, ii, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, ltid); + } + + /* Initialize weights to zero */ + if (!((handle->desc.N == 1) && (handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw))) { + for (work_item = work_begin; work_item < work_end; work_item++) { + ofm1 = work_item/(Cb*R*S); + ifm1 = (work_item%(Cb*R*S))/(R*S); + kj = ((work_item%(Cb*R*S))%(R*S))/S; + ki = ((work_item%(Cb*R*S))%(R*S))%S; + + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) = (element_filter_type)0; +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, ifm2, ofm1, ofm2, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock) = (element_filter_type)0; +#endif + } + } + } + } + + for (img = 0; img < handle->desc.N; img++) { + for (work_item = work_begin; work_item < work_end; work_item++) { + ofm1 = work_item/(Cb*R*S); + ifm1 = (work_item%(Cb*R*S))/(R*S); + kj = ((work_item%(Cb*R*S))%(R*S))/S; + ki = ((work_item%(Cb*R*S))%(R*S))%S; + oi = 0; + ii = ki; + for (oj = 0; oj < handle->ofh; oj += handle->upd_ofh_rb) { + ij = oj * handle->desc.u + kj; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input_use, img, ij, ii, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) ); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + gemm_kernel( &LIBXSMM_VLA_ACCESS(5, output, img, oj, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), + &LIBXSMM_VLA_ACCESS(5, input_use, img, ij, ii, ifm1, 0, IFH, IFW, handle->blocksifm, handle->ifmblock), + &LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock) ); +#endif + } + } + } + } else { + const float beta = ((handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb-1, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + + for (work_item = work_begin; work_item < work_end; work_item++) { + ofm1 = work_item/(Cb*R*S); + ifm1 = (work_item%(Cb*R*S))/(R*S); + kj = ((work_item%(Cb*R*S))%(R*S))/S; + ki = ((work_item%(Cb*R*S))%(R*S))%S; + oi = 0; + oj = 0; + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + img = 0; + img_block_size = handle->desc.N; + + if (kj == 0) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 1; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } else if (ki == 0) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi + 1, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii + 1, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } else { + if (kj == handle->desc.R-1) { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb-1; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } else { + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_global, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } + } + } + } + } else { + /* Here we are using batch-reduce kernel and hybrid minibatch/FM parallelization */ + /* FIXME: Hardcoed logic for N=27 */ + int group_size = (handle->desc.threads == 27 && handle->desc.N == 27 && handle->ofw == 14 && handle->desc.R == 1 && handle->desc.u == 1 && ltid >= 24) ? 3 : LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies); + int tile_id = ltid / LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies); + int tiles = handle->weight_copies; + int img_per_tile = LIBXSMM_UPDIV(handle->desc.N, tiles); + int my_in_tile_id = ltid % group_size; + int ifms_per_thread = LIBXSMM_UPDIV(handle->blocksifm, group_size); + int ofms_per_thread = LIBXSMM_UPDIV(handle->blocksofm, group_size); + int my_R_start = 0; + int my_R_end = handle->desc.R; + const float beta = ((handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw)) ? 0.f : 1.f; + gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode); + const float beta_flat = 0.0; + gemm_br_function br_gemm_kernel_flat = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->upd_ofw_rb, &LDA, &LDB, &LDC, NULL, &beta_flat, &l_flags, &prefetch_mode); + element_filter_type *weight_ptr_group = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_DECL(6, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_DECL(6, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + my_img_start = LIBXSMM_MIN(tile_id * img_per_tile, handle->desc.N); + my_img_end = LIBXSMM_MIN((tile_id+1) * img_per_tile, handle->desc.N); + my_ifm_start = LIBXSMM_MIN(my_in_tile_id * ifms_per_thread, handle->blocksifm ); + my_ifm_end = LIBXSMM_MIN((my_in_tile_id+1) * ifms_per_thread, handle->blocksifm ); + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + /* FIXME: Hardcoed logic for N=27 */ + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.C == 256 && handle->desc.K == 1024 && handle->ofh == 14 && handle->desc.u == 1) { + my_ofm_start = LIBXSMM_MIN(my_in_tile_id * ofms_per_thread, handle->blocksofm); + my_ofm_end = LIBXSMM_MIN((my_in_tile_id+1) * ofms_per_thread, handle->blocksofm); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + } + if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 14) { + int r_per_tile = LIBXSMM_UPDIV(handle->desc.R, group_size); + my_ifm_start = 0; + my_ifm_end = handle->blocksifm; + my_ofm_start = 0; + my_ofm_end = handle->blocksofm; + my_R_start = LIBXSMM_MIN(my_in_tile_id * r_per_tile, handle->desc.R); + my_R_end = LIBXSMM_MIN((my_in_tile_id+1) * r_per_tile, handle->desc.R); + } + block_ofm = my_ofm_end-my_ofm_start+1; + block_ifm = my_ifm_end-my_ifm_start+1; + img_block_size = my_img_end - my_img_start; + + if (handle->desc.N != handle->desc.threads) { + /* Use "flat" parallelism + reduction */ + const int work = handle->desc.R * handle->desc.S * handle->blocksofm * handle->blocksifm * handle->desc.N; + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : (work / handle->desc.threads) + 1; + const int work_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int work_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int work_item; + int Cb = handle->blocksifm; + int Kb = handle->blocksofm; + int R = handle->desc.R; + int S = handle->desc.S; + const int IFH = (handle->upd_pack_input == 1) ? handle->ifhp/handle->desc.u : IFHP; + const int IFW = (handle->upd_pack_input == 1) ? handle->ifwp/handle->desc.v : IFWP; + element_input_type *input_ptr_base = (handle->upd_pack_input == 1) ? (element_input_type*)((char*)handle->scratch + handle->upd_packing_padding_scratch_offset) : (element_input_type*)input_ptr_to_use; + LIBXSMM_VLA_DECL(5, element_input_type, input_use, (element_input_type*)input_ptr_base, IFH, IFW, handle->blocksifm, handle->ifmblock); + + /* If requested, pack input to avoid strided accesses */ + if (handle->upd_pack_input == 1) { + LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + const int img_chunk = (handle->desc.N % handle->desc.threads == 0) ? handle->desc.N/handle->desc.threads : (handle->desc.N/handle->desc.threads) + 1; + const int img_copy_start = LIBXSMM_MIN(ltid*img_chunk, handle->desc.N); + const int img_copy_end = LIBXSMM_MIN((ltid+1)*img_chunk, handle->desc.N); + + for (img = img_copy_start; img < img_copy_end; img++) { + for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) { + for (oj = 0; oj < handle->ofh; oj++) { + for (oi = 0; oi < handle->ofw; oi++) { + ij = oj * handle->desc.u; + ii = oi * handle->desc.v; + LIBXSMM_PRAGMA_SIMD + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_VLA_ACCESS(5, input_use, img, oj, oi, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ij, ii, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock); + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, ltid); + } + + /* Initialize weights to zero */ + if (handle->upd_ofw_rb != handle->ofw) { + for (work_item = work_begin; work_item < work_end; work_item++) { + img = work_item/(Cb*Kb*R*S); + ofm1 = (work_item%(Cb*Kb*R*S))/(Cb*R*S); + ifm1 = ((work_item%(Cb*Kb*R*S))%(Cb*R*S))/(R*S); + kj = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))/S; + ki = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))%S; + { + element_filter_type *weight_ptr_current = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset)+ img * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_DECL(6, element_filter_type, weight_current, (element_filter_type*)weight_ptr_current, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_DECL(6, element_filter_type, weight_current, (element_filter_type*)weight_ptr_current, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { + LIBXSMM_PRAGMA_SIMD + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) { +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_ACCESS(6, weight_current, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) = (element_filter_type)0; +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_ACCESS(6, weight_current, kj, ki, ifm1, ifm2, ofm1, ofm2, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock) = (element_filter_type)0; +#endif + } + } + } + } + } + + for (work_item = work_begin; work_item < work_end; work_item++) { + img = work_item/(Cb*Kb*R*S); + ofm1 = (work_item%(Cb*Kb*R*S))/(Cb*R*S); + ifm1 = ((work_item%(Cb*Kb*R*S))%(Cb*R*S))/(R*S); + kj = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))/S; + ki = (((work_item%(Cb*Kb*R*S))%(Cb*R*S))%(R*S))%S; + ii = 0 + ki; + ij = 0 + kj; + oj = 0; + oi = 0; + { + element_filter_type *weight_ptr_current = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + img * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_DECL(6, element_filter_type, weight_current, (element_filter_type*)weight_ptr_current, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_DECL(6, element_filter_type, weight_current, (element_filter_type*)weight_ptr_current, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock); +#endif + ind = 0; + for (j_br = 0; j_br < handle->ofh; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img , oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input_use, img, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel_flat(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_current, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel_flat(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_current, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } + } + } else { + /* May need to initialized private weights to zero */ + if (!((handle->upd_ofh_rb == handle->ofh) && (handle->upd_ofw_rb == handle->ofw))) { + for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++ ) { + for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++ ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) { +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + LIBXSMM_VLA_ACCESS(6, weight_private_group, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock) = (element_filter_type)0; +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + LIBXSMM_VLA_ACCESS(6, weight_private_group, kj, ki, ifm1, ifm2, ofm1, ofm2, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock) = (element_filter_type)0; +#endif + } + } + } + } + } + } + } + + if (handle->upd_loop_order == 0) { + for (img = my_img_start; img < my_img_end; img += img_block_size) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_private_group, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_private_group, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } + } + } + } + } + } + } + } + } + } + } else { + for (img = my_img_start; img < my_img_end; img += img_block_size) { + for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) { + for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) { + for (ojb = 0; ojb < handle->ofh; ojb += handle->upd_ofh_rb) { + for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) { + for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) { + for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->upd_ofh_rb,handle->ofh); oj+= handle->upd_ofh_rb) { + for (oi = 0; oi < handle->ofw; oi += handle->upd_ofw_rb) { + for (kj = my_R_start; kj < my_R_end; ++kj) { + for (ki = 0; ki < handle->desc.S; ++ki) { + ii = oi * handle->desc.u + ki; + ij = oj * handle->desc.v + kj; + ind = 0; + for (img_br = 0; img_br < img_block_size; img_br++) { + for (j_br = 0; j_br < handle->upd_ofh_rb; j_br++) { + A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, output, img + img_br, oj + j_br, oi, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock); + B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img + img_br, ij + j_br * handle->desc.u, ii, ifm1, 0, IFHP, IFWP, handle->blocksifm, handle->ifmblock); + ind++; + } + } + n_blocks = ind; +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_private_group, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), &n_blocks); +#endif +#if defined(LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK) + br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(6, weight_private_group, kj, ki, ifm1, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock), &n_blocks); +#endif + } + } + } + } + } + } + } + } + } + } + } + } + } +} + +if (handle->weight_copies > 1) { + /* reduce work-related variables */ + const int fm_blocking = (handle->ofmblock % 16 == 0) ? 16 : handle->ofmblock; + const int reduce_work = handle->blocksofm * handle->blocksifm * handle->desc.R * handle->desc.S * (handle->ofmblock/fm_blocking) * handle->ifmblock; + const int reduce_chunksize = (reduce_work % handle->desc.threads == 0) ? (reduce_work / handle->desc.threads) : (reduce_work / handle->desc.threads) + 1; + const int reduce_thr_begin = (ltid * reduce_chunksize < reduce_work) ? (ltid * reduce_chunksize) : reduce_work; + const int reduce_thr_end = ((ltid + 1) * reduce_chunksize < reduce_work) ? ((ltid + 1) * reduce_chunksize) : reduce_work; + + /* Perform reduction here */ + libxsmm_barrier_wait(handle->barrier, ltid); + + for ( ij = reduce_thr_begin; ij < reduce_thr_end; ij++ ) { + element_filter_type *weight_ptr_glb = (element_filter_type*) handle->grad_filter->data; +#if 1 + float weight_sum[64]; + int wtcnt = 0; + assert( handle->ofmblock <= 64 ); + + LIBXSMM_PRAGMA_SIMD + for ( wtcnt = 0; wtcnt < fm_blocking; ++wtcnt ) { + weight_sum[wtcnt] = 0.0f; + } + + for ( ii = 0; ii < handle->weight_copies; ii++ ) { + element_filter_type *weight_ptr_src = (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset)+ ii * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S + ij * fm_blocking; + LIBXSMM_PRAGMA_SIMD + for ( wtcnt = 0; wtcnt < fm_blocking; ++wtcnt ) { + weight_sum[wtcnt] += weight_ptr_src[wtcnt]; + } + } + + LIBXSMM_PRAGMA_SIMD + for ( wtcnt = 0; wtcnt < fm_blocking; ++wtcnt ) { + weight_ptr_glb[(ij*fm_blocking) + wtcnt] = weight_sum[wtcnt]; + } +#else + __m512 weight_sum = _mm512_setzero_ps(); + for ( ii = 0; ii < handle->weight_copies; ii++ ) { + element_filter_type *weight_ptr_src = (element_filter_type*)handle->scratch7 + ii * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S + ij * 16; + weight_sum = _mm512_add_ps(weight_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(weight_ptr_src)); + } + _mm512_storeu_ps(&weight_ptr_glb[ij*16], weight_sum); +#endif + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..950f0a230dfe54a4e9d90b5dfe6667b35bc1a16c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c @@ -0,0 +1,246 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + /* size variables, all const */ + /* here we assume that input and output blocking is similar */ + const int nBlocksIFm = handle->blocksifm; + const int nIFmBlock = handle->ifmblock; + const int nBlocksOFm = handle->blocksofm; + const int nOFmBlock = handle->ofmblock; + + /* computing first logical thread */ + const int ltid = tid - start_thread; + /* number of tasks that could be run in parallel */ + const int work = nBlocksIFm; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + + /* number of tasks for transpose that could be run in parallel */ + const int transpose_work = nBlocksIFm * nBlocksOFm; + /* compute chunk size */ + const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + + /* loop variables */ + int ofm1 = 0; + int ofm2 = 0; + int ifm1 = 0; + int ifm2 = 0; + int ifm1ofm1 = 0; + + LIBXSMM_VLA_DECL(3, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksOFm, nOFmBlock); + LIBXSMM_VLA_DECL(4, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, nIFmBlock, nOFmBlock); +#if defined(LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32) + float* dinput_f32_ptr = (float*)handle->scratch; + float* filter_f32_ptr = ((float*)handle->scratch)+((size_t)handle->desc.N*(size_t)handle->desc.C); + LIBXSMM_VLA_DECL(3, float, dinput, dinput_f32_ptr, nBlocksIFm, nIFmBlock); + LIBXSMM_VLA_DECL(4, float, filter_tr, filter_f32_ptr, nBlocksOFm, nOFmBlock, nIFmBlock); + + /* number of tasks that could be run in parallel */ + const int work_input = handle->desc.N * handle->desc.C; + /* compute chunk size */ + const int chunksize_input = (work_input % handle->desc.threads == 0) ? (work_input / handle->desc.threads) : ((work_input / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin_input = (ltid * chunksize_input < work_input) ? (ltid * chunksize_input) : work_input; + const int thr_end_input = ((ltid + 1) * chunksize_input < work_input) ? ((ltid + 1) * chunksize_input) : work_input; +#else + LIBXSMM_VLA_DECL(3, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksIFm, nIFmBlock); + LIBXSMM_VLA_DECL(4, element_filter_type, filter_tr, (element_filter_type*)handle->scratch, nBlocksOFm, nOFmBlock, nIFmBlock); +#endif + + /* lazy barrier init */ + libxsmm_barrier_init(handle->barrier, ltid); + + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + + for (ofm2 = 0; ofm2 < nOFmBlock; ++ofm2) { + for (ifm2 = 0; ifm2 < nIFmBlock; ++ifm2) { +#if defined(LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32) + union libxsmm_bfloat16_hp filter_f32; + filter_f32.i[0] = 0; + filter_f32.i[1] = LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, nIFmBlock, nOFmBlock); + LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, ofm2, ifm2, nBlocksOFm, nOFmBlock, nIFmBlock) = filter_f32.f; +#else + LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, ofm2, ifm2, nBlocksOFm, nOFmBlock, nIFmBlock) = + LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, nIFmBlock, nOFmBlock); +#endif + } + } + } + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); + + for ( ifm1 = thr_begin; ifm1 < thr_end; ++ifm1 ) { /* outer GEMM m-loop */ +#if 1 + gemm_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, nOFmBlock, nIFmBlock), + &LIBXSMM_VLA_ACCESS(3, doutput, 0, 0, 0, nBlocksOFm, nOFmBlock), + &LIBXSMM_VLA_ACCESS(3, dinput, 0, ifm1, 0, nBlocksIFm, nIFmBlock) ); +#else + const int nImg = handle->desc.N; + int img2; + + /* this is a simple replacement code using regular loops */ + for ( img2 = 0; img2 < nImg; ++img2 ) { + LIBXSMM_PRAGMA_SIMD + for ( ifm2 = 0; ifm2 < nIFmBlock; ++ifm2 ) { + LIBXSMM_VLA_ACCESS(3, dinput, img2, ifm1, ifm2, nBlocksIFm, nIFmBlock) = (element_output_type)0; + } + } + for ( ofm1 = 0; ofm1 < nBlocksOFm; ++ofm1 ) { /* outer GEMM k-loop */ + for ( ofm2 = 0; ofm2 < nOFmBlock; ++ofm2 ) { /* GEMM K-loop */ + for ( img2 = 0; img2 < nImg; ++img2 ) { /* GEMM n-loop */ + LIBXSMM_PRAGMA_SIMD + for ( ifm2 = 0; ifm2 < nIFmBlock; ++ifm2 ) { /* GEMM m-loop */ + LIBXSMM_VLA_ACCESS(3, dinput, img2, ifm1, ifm2, nBlocksIFm, nIFmBlock) += + LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, ofm2, ifm2, nBlocksOFm, nOFmBlock, nIFmBlock) * LIBXSMM_VLA_ACCESS(3, doutput, img2, ofm1, ofm2, nBlocksOFm, nOFmBlock); + } + } + } + } +#endif + } + +#if defined(LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32) + libxsmm_barrier_wait(handle->barrier, ltid); + + libxsmm_rne_convert_fp32_bf16( dinput_f32_ptr+thr_begin_input, ((element_input_type*)handle->grad_input->data)+thr_begin_input, thr_end_input-thr_begin_input ); +#endif + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + /* size variables, all const */ + const int nImg = handle->desc.N; + /* here we assume that input and output blocking is similar */ + const int nBlocksIFm = handle->blocksifm; + const int nIFmBlock = handle->ifmblock; + const int nBlocksOFm = handle->blocksofm; + const int nOFmBlock = handle->ofmblock; + + /* computing first logical thread */ + const int ltid = tid - start_thread; + /* number of tasks that could be run in parallel */ + const int work = nBlocksIFm * nBlocksOFm; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + + /* number of tasks for transpose that could be run in parallel */ + const int transpose_work = nBlocksIFm; + /* compute chunk size */ + const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + + /* loop variables */ + int img2 = 0; + int ifm1ofm1 = 0; + int ofm1 = 0; + int ifm1 = 0; + int ifm2 = 0; + + LIBXSMM_VLA_DECL(3, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, nIFmBlock); + LIBXSMM_VLA_DECL(3, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksOFm, nOFmBlock); +#if defined(LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32) + float* input_f32_ptr = (float*)handle->scratch; + float* dfilter_f32_ptr = ((float*)handle->scratch)+((size_t)handle->desc.N*(size_t)handle->desc.C); + LIBXSMM_VLA_DECL(3, float, input_tr, input_f32_ptr, nIFmBlock, nImg); + LIBXSMM_VLA_DECL(4, float, dfilter, dfilter_f32_ptr, nBlocksIFm, nIFmBlock, nOFmBlock); + + /* number of tasks that could be run in parallel */ + const int work_filter = handle->desc.C * handle->desc.K; + /* compute chunk size */ + const int chunksize_filter = (work_filter % handle->desc.threads == 0) ? (work_filter / handle->desc.threads) : ((work_filter / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin_filter = (ltid * chunksize_filter < work_filter) ? (ltid * chunksize_filter) : work_filter; + const int thr_end_filter = ((ltid + 1) * chunksize_filter < work_filter) ? ((ltid + 1) * chunksize_filter) : work_filter; +#else + LIBXSMM_VLA_DECL(4, element_filter_type, dfilter, (element_filter_type*)handle->grad_filter->data, nBlocksIFm, nIFmBlock, nOFmBlock); + LIBXSMM_VLA_DECL(3, element_input_type, input_tr, (element_input_type* )handle->scratch, nIFmBlock, nImg); +#endif + + /* lazy barrier init */ + libxsmm_barrier_init(handle->barrier, ltid); + + for (ifm1 = transpose_thr_begin; ifm1 < transpose_thr_end; ++ifm1) { + for (ifm2 = 0; ifm2 < nIFmBlock; ++ifm2) { + for (img2 = 0; img2 < nImg; ++img2) { +#if defined(LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32) + union libxsmm_bfloat16_hp input_f32; + input_f32.i[0] = 0; + input_f32.i[1] = LIBXSMM_VLA_ACCESS(3, input, img2, ifm1, ifm2, nBlocksIFm, nIFmBlock); + LIBXSMM_VLA_ACCESS(3, input_tr, ifm1, ifm2, img2, nIFmBlock, nImg) = input_f32.f; +#else + LIBXSMM_VLA_ACCESS(3, input_tr, ifm1, ifm2, img2, nIFmBlock, nImg) = + LIBXSMM_VLA_ACCESS(3, input, img2, ifm1, ifm2, nBlocksIFm, nIFmBlock); +#endif + } + } + } + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); + + for ( ifm1ofm1 = thr_begin; ifm1ofm1 < thr_end; ++ifm1ofm1 ) { /* outer GEMM m/n-loop */ + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + +#if 1 + gemm_kernel_upd( &LIBXSMM_VLA_ACCESS(3, doutput, 0, ofm1, 0, nBlocksOFm, nOFmBlock), + &LIBXSMM_VLA_ACCESS(3, input_tr, ifm1, 0, 0, nIFmBlock, nImg), + &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, 0, 0, nBlocksIFm, nIFmBlock, nOFmBlock) ); +#else + { + const int nImg = handle->desc.N; + int ifm2, ofm2; + + /* this is a simple replacement code using regular loops */ + for ( ifm2 = 0; ifm2 < nIFmBlock; ++ifm2 ) { + LIBXSMM_PRAGMA_SIMD + for ( ofm2 = 0; ofm2 < nOFmBlock; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, nIFmBlock, nOFmBlock) = (element_output_type)0; + } + } + for ( img2 = 0; img2 < nImg; ++img2 ) { /* GEMM k-loop */ + for ( ifm2 = 0; ifm2 < nIFmBlock; ++ifm2 ) { /* GEMM n-loop */ + LIBXSMM_PRAGMA_SIMD + for ( ofm2 = 0; ofm2 < nOFmBlock; ++ofm2 ) { /* GEMM m-loop */ + LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, nIFmBlock, nOFmBlock) += + LIBXSMM_VLA_ACCESS(3, doutput, img2, ofm1, ofm2, nBlocksOFm, nOFmBlock) * LIBXSMM_VLA_ACCESS(3, input_tr, ifm1, ifm2, img2, nIFmBlock, nImg); + } + } + } + } +#endif + } + +#if defined(LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32) + libxsmm_barrier_wait(handle->barrier, ltid); + + libxsmm_rne_convert_fp32_bf16( dfilter_f32_ptr+thr_begin_filter, ((element_input_type*)handle->grad_filter->data)+thr_begin_filter, thr_end_filter-thr_begin_filter ); +#endif + + libxsmm_barrier_wait(handle->barrier, ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..84313611260de765b6b0e40315d5dd5bc078b452 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c @@ -0,0 +1,346 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* here we assume that input and output blocking is similar */ +const int bn = handle->bn; +const int bk = handle->bk; +const int bc = handle->bc; +const int nBlocksIFm = handle->desc.C / bc; +const int nBlocksOFm = handle->desc.K / bk; +const int nBlocksMB = handle->desc.N / bn; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +/* number of tasks for transpose that could be run in parallel */ +const int eltwise_work = nBlocksOFm * nBlocksMB; +/* compute chunk size */ +const int eltwise_chunksize = (eltwise_work % handle->desc.threads == 0) ? (eltwise_work / handle->desc.threads) : ((eltwise_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int eltwise_thr_begin = (ltid * eltwise_chunksize < eltwise_work) ? (ltid * eltwise_chunksize) : eltwise_work; +const int eltwise_thr_end = ((ltid + 1) * eltwise_chunksize < eltwise_work) ? ((ltid + 1) * eltwise_chunksize) : eltwise_work; +int mb1ofm1; +#endif + +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS +/* number of tasks for transpose that could be run in parallel */ +const int dbias_work = nBlocksOFm; +/* compute chunk size */ +const int dbias_chunksize = (dbias_work % handle->desc.threads == 0) ? (dbias_work / handle->desc.threads) : ((dbias_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int dbias_thr_begin = (ltid * dbias_chunksize < dbias_work) ? (ltid * dbias_chunksize) : dbias_work; +const int dbias_thr_end = ((ltid + 1) * dbias_chunksize < dbias_work) ? ((ltid + 1) * dbias_chunksize) : dbias_work; +#endif + +/* loop variables */ +int ofm1 = 0, mb1 = 0, ofm2 = 0, mb2 = 0; + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +element_output_type *grad_output_ptr = ((element_output_type*)handle->scratch)+(handle->desc.C*handle->desc.K); +LIBXSMM_VLA_DECL(4, const element_output_type, doutput_orig, (element_output_type*)handle->grad_output->data, nBlocksOFm, bn, bk); +#else +element_output_type *grad_output_ptr = (element_output_type*)handle->grad_output->data; +#endif +LIBXSMM_VLA_DECL(4, element_output_type, doutput, grad_output_ptr, nBlocksOFm, bn, bk); + +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS +LIBXSMM_VLA_DECL(2, float, dbias, (float*) handle->grad_bias->data, handle->bk); +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU +LIBXSMM_VLA_DECL(4, unsigned char, relumask, (unsigned char*) handle->relumask->data, nBlocksOFm, handle->bn, handle->bk); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + float l_cur_out = LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU + l_cur_out = (LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) != 0) ? l_cur_out : (element_output_type)0; +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + l_cur_out = l_cur_out*(1.0f - l_cur_out); +#endif + LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } +} + +/* wait for eltwise to finish */ +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS) +for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, ofm2, handle->bk ) = 0.0f; + } + + for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, ofm2, handle->bk ) += LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); + } + } + } +} + +/* wait for eltwise to finish */ +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + const int use_2d_blocking = handle->bwd_2d_blocking; + + /* number of tasks that could be run in parallel */ + const int work = nBlocksIFm * nBlocksMB; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + + /* number of tasks for transpose that could be run in parallel */ + const int transpose_work = nBlocksIFm * nBlocksOFm; + /* compute chunk size */ + const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + + /* loop variables */ + int ifm1 = 0, ifm2 = 0, ifm1ofm1 = 0, mb1ifm1 = 0; + int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; + + LIBXSMM_VLA_DECL(4, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc, bk); + LIBXSMM_VLA_DECL(4, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksIFm, bn, bc); + LIBXSMM_VLA_DECL(4, element_filter_type, filter_tr, (element_filter_type*)handle->scratch, nBlocksOFm, bk, bc); + + unsigned long long blocks = nBlocksOFm; + int KB_BLOCKS = nBlocksOFm, BF = 1; + libxsmm_meltw_unary_param trans_param; + + BF = handle->bwd_bf; + KB_BLOCKS = nBlocksOFm/BF; + blocks = KB_BLOCKS; + + if (use_2d_blocking == 1) { + row_teams = handle->bwd_row_teams; + column_teams = handle->bwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams); + in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, column_teams); + my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksIFm); + my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksIFm); + } + + /* transpose weight */ + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + trans_param.in.primary = (void*)&LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk); + trans_param.out.primary = &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, 0, 0, nBlocksOFm, bk, bc); + handle->tr_kernel( &trans_param ) ; +#if 0 + for (ofm2 = 0; ofm2 < bk; ++ofm2) { + for (ifm2 = 0; ifm2 < bc; ++ifm2) { + LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, ofm2, ifm2, nBlocksOFm, bk, bc) = + LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, bc, bk); + } + } +#endif + } + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); + + if (use_2d_blocking == 1) { + if (BF > 1) { + for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) { + for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ofm1 == 0 ) { + for ( mb2 = 0; mb2 < bn; ++mb2 ) { + for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) { + LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc) = (element_input_type)0; + } + } + } + batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bk, bc ), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + } else { + for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, bk, bc), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, 0, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + } else { + if (BF > 1) { + for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) { + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + /* Initialize intermediate f32 tensor */ + if ( ofm1 == 0 ) { + for ( mb2 = 0; mb2 < bn; ++mb2 ) { + for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) { + LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc) = (element_input_type)0; + } + } + } + batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bk, bc ), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } else { + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, bk, bc ), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, 0, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + /* number of tasks that could be run in parallel */ + const int ofm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ofm_subtasks; + const int ifm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ifm_subtasks; + const int bbk = (handle->upd_2d_blocking == 1) ? bk : bk/ofm_subtasks; + const int bbc = (handle->upd_2d_blocking == 1) ? bc : bc/ifm_subtasks; + const int work = nBlocksIFm * ifm_subtasks * nBlocksOFm * ofm_subtasks; + const int Cck_work = nBlocksIFm * ifm_subtasks * ofm_subtasks; + const int Cc_work = nBlocksIFm * ifm_subtasks; + + /* 2D blocking parameters */ + int use_2d_blocking = handle->upd_2d_blocking; + int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; + + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int BF = handle->upd_bf; + + /* loop variables */ + int ifm1ofm1 = 0, ifm1 = 0, ifm2 = 0, bfn = 0, ii = 0, jj = 0; + + /* Batch reduce related variables */ + unsigned long long blocks = nBlocksMB/BF; + + LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, bn, bc); + LIBXSMM_VLA_DECL(4, element_filter_type, dfilter, (element_filter_type*)handle->grad_filter->data, nBlocksIFm, bc, bk); + + if (use_2d_blocking == 1) { + row_teams = handle->upd_row_teams; + column_teams = handle->upd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, row_teams); + in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams); + my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksIFm); + my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksIFm); + my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm); + } + + if (use_2d_blocking == 1) { + if (BF == 1) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) { + batchreduce_kernel_upd_zerobeta(&LIBXSMM_VLA_ACCESS(4, doutput, 0, ofm1, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, input, 0, ifm1, 0, 0, nBlocksIFm, bn, bc), + &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &blocks); + } + } + } else { + for (bfn = 0; bfn < BF; bfn++) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) { + /* initialize current work task to zero */ + if (bfn == 0) { + for (ii = 0; iibarrier, ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..a47e49c77f5cac72cc8942a2f7e802d1bd9adf59 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c @@ -0,0 +1,625 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int bn = handle->bn; +const int bk = handle->bk; +const int bc = handle->bc; +int lpb = 2; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +const int nBlocksIFm = handle->desc.C / handle->bc; +const int nBlocksOFm = handle->desc.K / handle->bk; +const int nBlocksMB = handle->desc.N / handle->bn; +int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, mb2 = 0, ofm2 = 0; +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) || defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS) +int iteri = 0, iterj = 0; +#endif +int performed_doutput_transpose = 0; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +/* number of tasks for transpose that could be run in parallel */ +const int eltwise_work = nBlocksOFm * nBlocksMB; +/* compute chunk size */ +const int eltwise_chunksize = (eltwise_work % handle->desc.threads == 0) ? (eltwise_work / handle->desc.threads) : ((eltwise_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int eltwise_thr_begin = (ltid * eltwise_chunksize < eltwise_work) ? (ltid * eltwise_chunksize) : eltwise_work; +const int eltwise_thr_end = ((ltid + 1) * eltwise_chunksize < eltwise_work) ? ((ltid + 1) * eltwise_chunksize) : eltwise_work; +#endif + +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS +/* number of tasks for transpose that could be run in parallel */ +const int dbias_work = nBlocksOFm; +/* compute chunk size */ +const int dbias_chunksize = (dbias_work % handle->desc.threads == 0) ? (dbias_work / handle->desc.threads) : ((dbias_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int dbias_thr_begin = (ltid * dbias_chunksize < dbias_work) ? (ltid * dbias_chunksize) : dbias_work; +const int dbias_thr_end = ((ltid + 1) * dbias_chunksize < dbias_work) ? ((ltid + 1) * dbias_chunksize) : dbias_work; +#endif + +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS +LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, dbias, (libxsmm_bfloat16*) handle->grad_bias->data, handle->bk); +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU +LIBXSMM_VLA_DECL(4, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk); +LIBXSMM_VLA_DECL(4, __mmask32, relubitmask, (__mmask32*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/32); +#endif + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +element_output_type *grad_output_ptr = (element_output_type*)((char*)handle->scratch + handle->doutput_scratch_mark); +element_output_type *tr_doutput_ptr = (element_output_type*)grad_output_ptr + handle->desc.N * handle->desc.K; +LIBXSMM_VLA_DECL(4, const element_output_type, doutput_orig, (element_output_type*)handle->grad_output->data, nBlocksOFm, bn, bk); +#else +element_output_type *grad_output_ptr = (element_output_type*)handle->grad_output->data; +element_output_type *tr_doutput_ptr = (element_output_type*)handle->scratch; +#endif +LIBXSMM_VLA_DECL(4, element_output_type, doutput, grad_output_ptr, nBlocksOFm, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, doutput_tr, tr_doutput_ptr, nBlocksMB, bn_lp, bk, lpb); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +/* Apply to doutput potential fusions */ +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +if (bk % 32 == 0) { + for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; iterj += 32 ) { + __m512i cur_out_reg = _mm512_loadu_si512(&LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk)); +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + __m512 cur_out_reg_0, cur_out_reg_1; + const __m512 ones = _mm512_set1_ps(1.0f); +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU + __m512i zero_reg = _mm512_setzero_si512(); + __mmask32 relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK32 (&LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, iteri, iterj/32, nBlocksOFm, handle->bn, handle->bk/32)); + cur_out_reg = _mm512_mask_blend_epi16 (relumask, zero_reg, cur_out_reg); +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + cur_out_reg_0 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(cur_out_reg, 0)),16)); + cur_out_reg_1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(cur_out_reg, 1)),16)); + cur_out_reg_0 = _mm512_mul_ps(cur_out_reg_0, _mm512_sub_ps(ones, cur_out_reg_0)); + cur_out_reg_1 = _mm512_mul_ps(cur_out_reg_1, _mm512_sub_ps(ones, cur_out_reg_1)); + cur_out_reg = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(cur_out_reg_1, cur_out_reg_0); +#endif + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk), cur_out_reg); + } + } + + /* If in UPD pass, also perform transpose of doutput */ + if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + bf16_vnni_reformat((element_output_type*)&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, 0, 0, 0, nBlocksMB, bn_lp, bk, lpb), bk, bn, bk, bn); + } + } +} else { + for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; ++iterj ) { + element_output_type l_cur_out = LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + float l_cur_out_f32 = 0; + libxsmm_bfloat16_hp tmp; +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU + l_cur_out = (element_output_type)((LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk) != 0) ? l_cur_out : (element_output_type)0); +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID + tmp.i[0] = 0; + tmp.i[1] = l_cur_out; + l_cur_out_f32 = tmp.f; + l_cur_out_f32 = l_cur_out_f32*(1.0f - l_cur_out_f32); + libxsmm_rne_convert_fp32_bf16(&l_cur_out_f32, &l_cur_out, 1); +#endif + LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + + /* If in UPD pass, also perform transpose of doutput */ + if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + for (mb2 = 0; mb2 < bn; mb2++) { + for (ofm2 = 0; ofm2 < bk; ofm2++) { + LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, mb2/lpb, ofm2, mb2%lpb, nBlocksMB, bn_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, bn, bk); + } + } + } + } +} +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + performed_doutput_transpose = 1; +} +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS) +/* Accumulation of bias happens in f32 */ +{ + float *scratch_dbias = (float*) ((element_output_type*)handle->scratch + handle->desc.N * (handle->desc.K + handle->desc.C) + ltid * bk * 2); + if (handle->bk % 16 == 0) { + __m512 zero_reg = _mm512_setzero_ps(); + __m512 doutput_reg = _mm512_setzero_ps(); + __m512 dbias_reg = _mm512_setzero_ps(); + for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) { + for ( iterj = 0; iterj < handle->bk; iterj += 16 ) { + _mm512_storeu_ps(scratch_dbias+iterj, zero_reg); + } + for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) { + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; iterj += 16 ) { + doutput_reg = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk))); + dbias_reg = LIBXSMM_INTRINSICS_MM512_LOAD_PS(scratch_dbias+iterj); + dbias_reg = _mm512_add_ps(dbias_reg, doutput_reg); + _mm512_storeu_ps(scratch_dbias+iterj, dbias_reg); + } + } + } + for ( iterj = 0; iterj < handle->bk; iterj += 16 ) { + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, iterj, handle->bk ), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(scratch_dbias+iterj)) ); + } + } + } else { + for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) { + for ( iterj = 0; iterj < handle->bk; ++iterj ) { + scratch_dbias[iterj] = 0.0; + } + for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) { + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; ++iterj ) { + float doutput_f32 = 0; + libxsmm_bfloat16_hp tmp; + tmp.i[0] = 0; + tmp.i[1] = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk); + doutput_f32 = tmp.f; + scratch_dbias[iterj] += doutput_f32; + } + } + } + libxsmm_rne_convert_fp32_bf16(scratch_dbias, &LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, 0, handle->bk ), handle->bk); + } + } +} + +/* wait for eltwise to finish */ +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ){ + int use_2d_blocking = handle->bwd_2d_blocking; + + /* number of tasks that could be run in parallel */ + const int work = nBlocksIFm * nBlocksMB; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + + /* number of tasks for transpose that could be run in parallel */ + const int transpose_work = nBlocksIFm * nBlocksOFm; + /* compute chunk size */ + const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + + /* loop variables */ + int ifm1 = 0, ifm2 = 0, ifm1ofm1 = 0, mb1ifm1 = 0; + int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; + + LIBXSMM_VLA_DECL(5, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, bk, lpb); + LIBXSMM_VLA_DECL(4, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksIFm, bn, bc); + LIBXSMM_VLA_DECL(5, element_filter_type, filter_tr, (element_filter_type*)handle->scratch, nBlocksOFm, bk_lp, bc, lpb); + float* temp_output = (float*)handle->scratch + (handle->desc.C * handle->desc.K)/2; + LIBXSMM_VLA_DECL(4, float, dinput_f32, (float*) temp_output, nBlocksIFm, bn, bc); + + unsigned long long blocks = nBlocksOFm; + int KB_BLOCKS = nBlocksOFm, BF = 1; + BF = handle->bwd_bf; + KB_BLOCKS = nBlocksOFm/BF; + blocks = KB_BLOCKS; + + if (use_2d_blocking == 1) { + row_teams = handle->bwd_row_teams; + column_teams = handle->bwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams); + in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, column_teams); + my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksIFm); + my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksIFm); + } + + if (handle->desc.K > 1) { + /* transpose weight */ + if ((bk % 16 == 0) && (bc % 16 == 0)) { + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + bf16_vnni_transpose((element_filter_type*)&LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1, 0, 0, 0, nBlocksIFm, bc_lp, bk, lpb), (element_filter_type*)&LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), bk, bc, bk, bc); + } + } else { + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + for (ofm2 = 0; ofm2 < bk; ++ofm2) { + for (ifm2 = 0; ifm2 < bc; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1, ofm2/lpb, ifm2, ofm2%lpb, nBlocksOFm, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1, ifm2/lpb, ofm2, ifm2%lpb, nBlocksIFm, bc_lp, bk, lpb); + } + } + } + } + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); + + if (use_2d_blocking == 1) { + if (BF > 1) { + for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) { + for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ofm1 == 0 ) { + memset(&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), 0, bn*bc*sizeof(float)); + } + batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ofm1 == BF-1 ) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), bn*bc); + } + } + } + } + } else { + for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, 0, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, 0, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + } else { + if (BF > 1) { + for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) { + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + /* Initialize intermediate f32 tensor */ + if ( ofm1 == 0 ) { + memset(&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), 0, bn*bc*sizeof(float)); + } + batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ofm1 == BF-1 ) { + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), bn*bc); + } + } + } + } else { + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, 0, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, 0, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + } else { + /* Special case when K = 1 */ + /* number of tasks for doutput copy that could be run in parallel */ + const int copy_work_output = nBlocksMB * nBlocksOFm; + /* compute chunk size */ + const int copy_chunksize = (copy_work_output % handle->desc.threads == 0) ? (copy_work_output / handle->desc.threads) : ((copy_work_output / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int copy_thr_begin = (ltid * copy_chunksize < copy_work_output) ? (ltid * copy_chunksize) : copy_work_output; + const int copy_thr_end = ((ltid + 1) * copy_chunksize < copy_work_output) ? ((ltid + 1) * copy_chunksize) : copy_work_output; + LIBXSMM_VLA_DECL(5, element_filter_type, filter_tr_padded, (element_filter_type*)handle->scratch, nBlocksOFm, 1, bc, lpb); + LIBXSMM_VLA_DECL(4, element_output_type, doutput_padded, (element_output_type*)handle->scratch + handle->desc.C * 2, nBlocksOFm, bn, lpb); + + /* Copy in weights and doutput in a padded buffer */ + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + ofm2 = 0; + for (ifm2 = 0; ifm2 < bc; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, filter_tr_padded, ifm1, ofm1, ofm2/lpb, ifm2, ofm2%lpb, nBlocksOFm, 1, bc, lpb) = LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1, ifm2/lpb, ofm2, ifm2%lpb, nBlocksIFm, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, filter_tr_padded, ifm1, ofm1, ofm2/lpb, ifm2, 1, nBlocksOFm, 1, bc, lpb) = (element_filter_type)0; + } + } + + for (mb1ofm1 = copy_thr_begin; mb1ofm1 < copy_thr_end; ++mb1ofm1) { + mb1 = mb1ofm1 / nBlocksOFm; + ofm1 = mb1ofm1 % nBlocksOFm; + ofm2 = 0; + for (mb2 = 0; mb2 < bn; ++mb2) { + LIBXSMM_VLA_ACCESS(4, doutput_padded, mb1, ofm1, mb2, 0, nBlocksOFm, bn, 2) = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, 0, nBlocksOFm, bn, bk); + LIBXSMM_VLA_ACCESS(4, doutput_padded, mb1, ofm1, mb2, 1, nBlocksOFm, bn, 2) = (element_output_type)0; + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter_tr_padded, ifm1, 0, 0, 0, 0, nBlocksOFm, 1, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput_padded, mb1, 0, 0, 0, nBlocksOFm, bn, 2), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + /* number of tasks that could be run in parallel */ + const int ofm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ofm_subtasks; + const int ifm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ifm_subtasks; + const int bbk = (handle->upd_2d_blocking == 1) ? bk : bk/ofm_subtasks; + const int bbc = (handle->upd_2d_blocking == 1) ? bc : bc/ifm_subtasks; + const int work = nBlocksIFm * ifm_subtasks * nBlocksOFm * ofm_subtasks; + const int Cck_work = nBlocksIFm * ifm_subtasks * ofm_subtasks; + const int Cc_work = nBlocksIFm * ifm_subtasks; + + /* 2D blocking parameters */ + int use_2d_blocking = handle->upd_2d_blocking; + int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; + + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int BF = handle->upd_bf; + + /* loop variables */ + int ifm1ofm1 = 0, ifm1 = 0, ifm2 = 0, bfn = 0, ii = 0, jj = 0, mb1ifm1 = 0, jc = 0, jk = 0; + + /* Batch reduce related variables */ + unsigned long long blocks = nBlocksMB/BF; + + LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, bn, bc); + LIBXSMM_VLA_DECL(5, element_filter_type, dfilter, (element_filter_type*)handle->grad_filter->data, nBlocksIFm, bc_lp, bk, lpb); + + /* Set up tensors for transposing/scratch before vnni reformatting dfilter */ + element_input_type *tr_inp_ptr = (element_input_type*) ((element_output_type*)handle->scratch + handle->desc.N * handle->desc.K); + float *dfilter_f32_ptr = (float*) ((element_input_type*)tr_inp_ptr + handle->desc.N * handle->desc.C); + element_filter_type *dfilter_scratch = (element_filter_type*) ((float*)dfilter_f32_ptr + handle->desc.C * handle->desc.K) + ltid * bc * bk; + + LIBXSMM_VLA_DECL(4, element_input_type, input_tr, (element_input_type*)tr_inp_ptr, nBlocksMB, bc, bn); + LIBXSMM_VLA_DECL(4, float, dfilter_f32, (float*)dfilter_f32_ptr, nBlocksIFm, bc, bk); + LIBXSMM_VLA_DECL(2, element_filter_type, dfilter_block, (element_filter_type*)dfilter_scratch, bk); + + const int tr_out_work = nBlocksMB * nBlocksOFm; + const int tr_out_chunksize = (tr_out_work % handle->desc.threads == 0) ? (tr_out_work / handle->desc.threads) : ((tr_out_work / handle->desc.threads) + 1); + const int tr_out_thr_begin = (ltid * tr_out_chunksize < tr_out_work) ? (ltid * tr_out_chunksize) : tr_out_work; + const int tr_out_thr_end = ((ltid + 1) * tr_out_chunksize < tr_out_work) ? ((ltid + 1) * tr_out_chunksize) : tr_out_work; + + const int tr_inp_work = nBlocksMB * nBlocksIFm; + const int tr_inp_chunksize = (tr_inp_work % handle->desc.threads == 0) ? (tr_inp_work / handle->desc.threads) : ((tr_inp_work / handle->desc.threads) + 1); + const int tr_inp_thr_begin = (ltid * tr_inp_chunksize < tr_inp_work) ? (ltid * tr_inp_chunksize) : tr_inp_work; + const int tr_inp_thr_end = ((ltid + 1) * tr_inp_chunksize < tr_inp_work) ? ((ltid + 1) * tr_inp_chunksize) : tr_inp_work; + + /* These are used for the vnni reformatting of the f32 output */ + __m256i c0, c1; + __m512 a01, b01; + __m512i c01 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + + if (use_2d_blocking == 1) { + row_teams = handle->upd_row_teams; + column_teams = handle->upd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, row_teams); + in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams); + my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksIFm); + my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksIFm); + my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm); + } + + /* Required upfront tranposes */ + if (bc % 32 == 0) { + for (mb1ifm1 = tr_inp_thr_begin; mb1ifm1 < tr_inp_thr_end; mb1ifm1++) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + bf16_transpose((element_input_type*)&LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &LIBXSMM_VLA_ACCESS(4, input_tr, ifm1, mb1, 0, 0, nBlocksMB, bc, bn), bc, bn, bc, bn); + } + } else { + for (mb1ifm1 = tr_inp_thr_begin; mb1ifm1 < tr_inp_thr_end; mb1ifm1++) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + for (mb2 = 0; mb2 < bn; mb2++) { + for (ifm2 = 0; ifm2 < bc; ifm2++) { + LIBXSMM_VLA_ACCESS(4, input_tr, ifm1, mb1, ifm2, mb2, nBlocksMB, bc, bn) = LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc); + } + } + } + } + + if (performed_doutput_transpose == 0) { + if (bk % 32 == 0) { + for (mb1ofm1 = tr_out_thr_begin; mb1ofm1 < tr_out_thr_end; mb1ofm1++) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + bf16_vnni_reformat((element_output_type*)&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, 0, 0, 0, nBlocksMB, bn_lp, bk, lpb), bk, bn, bk, bn); + } + } else { + for (mb1ofm1 = tr_out_thr_begin; mb1ofm1 < tr_out_thr_end; mb1ofm1++) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + for (mb2 = 0; mb2 < bn; mb2++) { + for (ofm2 = 0; ofm2 < bk; ofm2++) { + LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, mb2/lpb, ofm2, mb2%lpb, nBlocksMB, bn_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, bn, bk); + } + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if (use_2d_blocking == 1) { + if (BF == 1) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) { + batchreduce_kernel_upd_zerobeta(&LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, 0, 0, 0, 0, nBlocksMB, bn_lp, bk, lpb), &LIBXSMM_VLA_ACCESS(4, input_tr, ifm1, 0, 0, 0, nBlocksMB, bc, bn), &LIBXSMM_VLA_ACCESS(2, dfilter_block, 0, 0, bk), &blocks); + /* TODO: Make this vnni reformating in the kernel... */ + /* Copy result back to vnni format */ + if ((bc % 2 == 0) && (bk % 16 == 0)) { + for (jc = 0; jc < bc; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c1 = _mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dfilter_block, jc+1,jk, bk)); + c0 = _mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dfilter_block, jc, jk, bk)); + c01 = _mm512_inserti64x4(c01, c0, 0); + c01 = _mm512_inserti64x4(c01, c1, 1); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dfilter, ofm1, ifm1, jc/lpb, jk, 0, nBlocksIFm, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } else { + for (ii = 0; ii < bc; ii++) { + for (jj = 0; jj < bk; jj++) { + LIBXSMM_VLA_ACCESS(5, dfilter, ofm1, ifm1, ii/lpb, jj, ii%lpb, nBlocksIFm, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, dfilter_block, ii, jj, bk); + } + } + } + } + } + } else { + for (bfn = 0; bfn < BF; bfn++) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) { + /* initialize current work task to zero */ + if (bfn == 0) { + for (ii = 0; iibarrier, ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..19538fd90a1227c40cb30eab695f089d0a950a89 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c @@ -0,0 +1,604 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int bn = handle->bn; +const int bk = handle->bk; +const int bc = handle->bc; +int lpb = 2; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +const int nBlocksIFm = handle->desc.C / handle->bc; +const int nBlocksOFm = handle->desc.K / handle->bk; +const int nBlocksMB = handle->desc.N / handle->bn; +int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, mb2 = 0, ofm2 = 0; +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) || defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS) +int iteri = 0, iterj = 0; +#endif +int performed_doutput_transpose = 0; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +/* number of tasks for transpose that could be run in parallel */ +const int eltwise_work = nBlocksOFm * nBlocksMB; +/* compute chunk size */ +const int eltwise_chunksize = (eltwise_work % handle->desc.threads == 0) ? (eltwise_work / handle->desc.threads) : ((eltwise_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int eltwise_thr_begin = (ltid * eltwise_chunksize < eltwise_work) ? (ltid * eltwise_chunksize) : eltwise_work; +const int eltwise_thr_end = ((ltid + 1) * eltwise_chunksize < eltwise_work) ? ((ltid + 1) * eltwise_chunksize) : eltwise_work; +#endif + +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS +/* number of tasks for transpose that could be run in parallel */ +const int dbias_work = nBlocksOFm; +/* compute chunk size */ +const int dbias_chunksize = (dbias_work % handle->desc.threads == 0) ? (dbias_work / handle->desc.threads) : ((dbias_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int dbias_thr_begin = (ltid * dbias_chunksize < dbias_work) ? (ltid * dbias_chunksize) : dbias_work; +const int dbias_thr_end = ((ltid + 1) * dbias_chunksize < dbias_work) ? ((ltid + 1) * dbias_chunksize) : dbias_work; +#endif + +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS +LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, dbias, (libxsmm_bfloat16*) handle->grad_bias->data, handle->bk); +#endif +#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU +LIBXSMM_VLA_DECL(4, __mmask32, relubitmask, (__mmask32*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/32); +#endif + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +element_output_type *grad_output_ptr = (element_output_type*)((char*)handle->scratch + handle->doutput_scratch_mark); +element_output_type *tr_doutput_ptr = (element_output_type*)grad_output_ptr + handle->desc.N * handle->desc.K; +LIBXSMM_VLA_DECL(4, const element_output_type, doutput_orig, (element_output_type*)handle->grad_output->data, nBlocksOFm, bn, bk); +#else +element_output_type *grad_output_ptr = (element_output_type*)handle->grad_output->data; +element_output_type *tr_doutput_ptr = (element_output_type*)handle->scratch; +#endif +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) +libxsmm_meltw_unary_param relu_params; +libxsmm_meltwfunction_unary relu_kernel = handle->bwd_relu_kernel; +#endif +LIBXSMM_VLA_DECL(4, element_output_type, doutput, grad_output_ptr, nBlocksOFm, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, doutput_tr, tr_doutput_ptr, nBlocksMB, bn_lp, bk, lpb); + +libxsmm_meltwfunction_unary eltwise_kernel = handle->bwd_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); +bwd_tile_config_kernel(NULL, NULL, NULL); + +/* Apply to doutput potential fusions */ +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) +LIBXSMM_UNUSED(iteri); +LIBXSMM_UNUSED(iterj); +for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1/nBlocksOFm; + ofm1 = mb1ofm1%nBlocksOFm; + + relu_params.in.primary = (void*) &LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); + relu_params.out.primary = &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); + relu_params.in.secondary = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); + relu_kernel(&relu_params); + + /* If in UPD pass, also perform transpose of doutput */ + if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + bf16_vnni_reformat((element_output_type*)&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, 0, 0, 0, nBlocksMB, bn_lp, bk, lpb), bk, bn, bk, bn); + } +} + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + performed_doutput_transpose = 1; +} +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID) +if (bk % 32 == 0) { + for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; iterj += 32 ) { + __m512i cur_out_reg = _mm512_loadu_si512(&LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk)); + __m512 cur_out_reg_0, cur_out_reg_1; + const __m512 ones = _mm512_set1_ps(1.0f); + cur_out_reg_0 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(cur_out_reg, 0)),16)); + cur_out_reg_1 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(cur_out_reg, 1)),16)); + cur_out_reg_0 = _mm512_mul_ps(cur_out_reg_0, _mm512_sub_ps(ones, cur_out_reg_0)); + cur_out_reg_1 = _mm512_mul_ps(cur_out_reg_1, _mm512_sub_ps(ones, cur_out_reg_1)); + cur_out_reg = LIBXSMM_INTRINSICS_MM512_CVT2_FP32_BF16(cur_out_reg_1, cur_out_reg_0); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk), cur_out_reg); +#ifdef USE_CLDEMOTE + _mm_cldemote(&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk)); +#endif + } + } + + /* If in UPD pass, also perform transpose of doutput */ + if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + bf16_vnni_reformat((element_output_type*)&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, 0, 0, 0, nBlocksMB, bn_lp, bk, lpb), bk, bn, bk, bn); + } + } +} else { + for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; ++iterj ) { + element_output_type l_cur_out = LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk); + float l_cur_out_f32 = 0; + libxsmm_bfloat16_hp tmp; + tmp.i[0] = 0; + tmp.i[1] = l_cur_out; + l_cur_out_f32 = tmp.f; + l_cur_out_f32 = l_cur_out_f32*(1.0f - l_cur_out_f32); + libxsmm_rne_convert_fp32_bf16(&l_cur_out_f32, &l_cur_out, 1); + LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + + /* If in UPD pass, also perform transpose of doutput */ + if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + for (mb2 = 0; mb2 < bn; mb2++) { + for (ofm2 = 0; ofm2 < bk; ofm2++) { + LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, mb2/lpb, ofm2, mb2%lpb, nBlocksMB, bn_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, bn, bk); + } + } + } + } +} +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + performed_doutput_transpose = 1; +} +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +#if defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS) +/* Accumulation of bias happens in f32 */ +{ + float *scratch_dbias = (float*) ((element_output_type*)handle->scratch + handle->desc.N * (handle->desc.K + handle->desc.C) + ltid * bk * 2); + if (handle->bk % 32 == 0) { + for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) { + for ( iterj = 0; iterj < handle->bk; iterj += 32 ) { + __m512 doutput_reg_0, doutput_reg_1, dbias_reg_0, dbias_reg_1; + dbias_reg_0 = _mm512_setzero_ps(); + dbias_reg_1 = _mm512_setzero_ps(); + for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) { + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + doutput_reg_0 = _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk)); + doutput_reg_1 = _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj+16, nBlocksOFm, handle->bn, handle->bk)); + dbias_reg_0 = _mm512_add_ps(dbias_reg_0, doutput_reg_0); + dbias_reg_1 = _mm512_add_ps(dbias_reg_1, doutput_reg_1); + } + } + _mm512_store_si512(&LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, iterj, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(dbias_reg_1, dbias_reg_0)); + } + } + } else { + for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) { + for ( iterj = 0; iterj < handle->bk; ++iterj ) { + scratch_dbias[iterj] = 0.0; + } + for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) { + for ( iteri = 0; iteri < handle->bn; ++iteri ) { + for ( iterj = 0; iterj < handle->bk; ++iterj ) { + float doutput_f32 = 0; + libxsmm_bfloat16_hp tmp; + tmp.i[0] = 0; + tmp.i[1] = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, iteri, iterj, nBlocksOFm, handle->bn, handle->bk); + doutput_f32 = tmp.f; + scratch_dbias[iterj] += doutput_f32; + } + } + } + libxsmm_rne_convert_fp32_bf16(scratch_dbias, &LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, 0, handle->bk ), handle->bk); + } + } +} + +/* wait for eltwise to finish */ +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ){ + int use_2d_blocking = handle->bwd_2d_blocking; + + /* number of tasks that could be run in parallel */ + const int work = nBlocksIFm * nBlocksMB; + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + + /* number of tasks for transpose that could be run in parallel */ + const int transpose_work = nBlocksIFm * nBlocksOFm; + /* compute chunk size */ + const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work; + const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work; + + /* loop variables */ + int ifm1 = 0, ifm2 = 0, ifm1ofm1 = 0, mb1ifm1 = 0; + int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; + + LIBXSMM_VLA_DECL(5, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, bk, lpb); + LIBXSMM_VLA_DECL(4, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksIFm, bn, bc); + LIBXSMM_VLA_DECL(5, element_filter_type, filter_tr, (element_filter_type*)handle->scratch, nBlocksOFm, bk_lp, bc, lpb); + float* temp_output = (float*)handle->scratch + (handle->desc.C * handle->desc.K)/2; + LIBXSMM_VLA_DECL(4, float, dinput_f32, (float*) temp_output, nBlocksIFm, bn, bc); + + unsigned long long blocks = nBlocksOFm; + int KB_BLOCKS = nBlocksOFm, BF = 1; + BF = handle->bwd_bf; + KB_BLOCKS = nBlocksOFm/BF; + blocks = KB_BLOCKS; + + if (use_2d_blocking == 1) { + row_teams = handle->bwd_row_teams; + column_teams = handle->bwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = (nBlocksMB + row_teams-1)/row_teams; + in_tasks_per_thread = (nBlocksIFm + column_teams-1)/column_teams; + my_im_start = LIBXSMM_MIN( my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN( (my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN( my_col_id * in_tasks_per_thread, nBlocksIFm); + my_in_end = LIBXSMM_MIN( (my_col_id+1) * in_tasks_per_thread, nBlocksIFm); + } + + /* transpose weight */ + if ((bk % 16 == 0) && (bc % 16 == 0)) { + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + bf16_vnni_transpose((element_filter_type*)&LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1, 0, 0, 0, nBlocksIFm, bc_lp, bk, lpb), (element_filter_type*)&LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), bk, bc, bk, bc); + } + } else { + for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) { + ofm1 = ifm1ofm1 / nBlocksIFm; + ifm1 = ifm1ofm1 % nBlocksIFm; + for (ofm2 = 0; ofm2 < bk; ++ofm2) { + for (ifm2 = 0; ifm2 < bc; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1, ofm2/lpb, ifm2, ofm2%lpb, nBlocksOFm, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1, ifm2/lpb, ofm2, ifm2%lpb, nBlocksIFm, bc_lp, bk, lpb); + } + } + } + } + + /* wait for transpose to finish */ + libxsmm_barrier_wait(handle->barrier, ltid); + + if (use_2d_blocking == 1) { + if (BF > 1) { + for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) { + for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ofm1 == 0 ) { + memset(&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), 0, bn*bc*sizeof(float)); + } +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), handle->bn*handle->bc*sizeof(float)); + if ( ofm1 == BF-1 ) { + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), handle->bn*handle->bc*sizeof(libxsmm_bfloat16)); + } +#endif + batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ofm1 == BF-1 ) { + eltwise_params.in.primary = &LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc); + eltwise_params.out.primary = &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc); + eltwise_kernel(&eltwise_params); + } + } + } + } + } else { + for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), handle->bn*handle->bc*sizeof(libxsmm_bfloat16)); +#endif + bf16_batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, 0, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, 0, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + } else { + if (BF > 1) { + for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) { + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + /* Initialize intermediate f32 tensor */ + if ( ofm1 == 0 ) { + memset(&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), 0, bn*bc*sizeof(float)); + } +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), handle->bn*handle->bc*sizeof(float)); + if ( ofm1 == BF-1 ) { + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), handle->bn*handle->bc*sizeof(libxsmm_bfloat16)); + } +#endif + batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ofm1 == BF-1 ) { + eltwise_params.in.primary = &LIBXSMM_VLA_ACCESS(4, dinput_f32, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc); + eltwise_params.out.primary = &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc); + eltwise_kernel(&eltwise_params); + } + } + } + } else { + for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), handle->bn*handle->bc*sizeof(libxsmm_bfloat16)); +#endif + bf16_batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter_tr, ifm1, 0, 0, 0, 0, nBlocksOFm, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, doutput, mb1, 0, 0, 0, nBlocksOFm, bn, bk), + &LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks); + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) { + /* number of tasks that could be run in parallel */ + const int ofm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ofm_subtasks; + const int ifm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ifm_subtasks; + const int bbk = (handle->upd_2d_blocking == 1) ? bk : bk/ofm_subtasks; + const int bbc = (handle->upd_2d_blocking == 1) ? bc : bc/ifm_subtasks; + const int work = nBlocksIFm * ifm_subtasks * nBlocksOFm * ofm_subtasks; + const int Cck_work = nBlocksIFm * ifm_subtasks * ofm_subtasks; + const int Cc_work = nBlocksIFm * ifm_subtasks; + + /* 2D blocking parameters */ + int use_2d_blocking = handle->upd_2d_blocking; + int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; + + /* compute chunk size */ + const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); + /* compute thr_begin and thr_end */ + const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; + const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + int BF = handle->upd_bf; + + /* loop variables */ + int ifm1ofm1 = 0, ifm1 = 0, ifm2 = 0, bfn = 0, ii = 0, jj = 0, mb1ifm1 = 0, jc = 0, jk = 0; + + /* Batch reduce related variables */ + unsigned long long blocks = nBlocksMB/BF; + + LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, bn, bc); + LIBXSMM_VLA_DECL(5, element_filter_type, dfilter, (element_filter_type*)handle->grad_filter->data, nBlocksIFm, bc_lp, bk, lpb); + + /* Set up tensors for transposing/scratch before vnni reformatting dfilter */ + element_input_type *tr_inp_ptr = (element_input_type*) ((element_output_type*)handle->scratch + handle->desc.N * handle->desc.K); + float *dfilter_f32_ptr = (float*) ((element_input_type*)tr_inp_ptr + handle->desc.N * handle->desc.C); + element_filter_type *dfilter_scratch = (element_filter_type*) ((float*)dfilter_f32_ptr + handle->desc.C * handle->desc.K) + ltid * bc * bk; + + LIBXSMM_VLA_DECL(4, element_input_type, input_tr, (element_input_type*)tr_inp_ptr, nBlocksMB, bc, bn); + LIBXSMM_VLA_DECL(4, float, dfilter_f32, (float*)dfilter_f32_ptr, nBlocksIFm, bc, bk); + LIBXSMM_VLA_DECL(2, element_filter_type, dfilter_block, (element_filter_type*)dfilter_scratch, bk); + + const int tr_out_work = nBlocksMB * nBlocksOFm; + const int tr_out_chunksize = (tr_out_work % handle->desc.threads == 0) ? (tr_out_work / handle->desc.threads) : ((tr_out_work / handle->desc.threads) + 1); + const int tr_out_thr_begin = (ltid * tr_out_chunksize < tr_out_work) ? (ltid * tr_out_chunksize) : tr_out_work; + const int tr_out_thr_end = ((ltid + 1) * tr_out_chunksize < tr_out_work) ? ((ltid + 1) * tr_out_chunksize) : tr_out_work; + + const int tr_inp_work = nBlocksMB * nBlocksIFm; + const int tr_inp_chunksize = (tr_inp_work % handle->desc.threads == 0) ? (tr_inp_work / handle->desc.threads) : ((tr_inp_work / handle->desc.threads) + 1); + const int tr_inp_thr_begin = (ltid * tr_inp_chunksize < tr_inp_work) ? (ltid * tr_inp_chunksize) : tr_inp_work; + const int tr_inp_thr_end = ((ltid + 1) * tr_inp_chunksize < tr_inp_work) ? ((ltid + 1) * tr_inp_chunksize) : tr_inp_work; + + /* These are used for the vnni reformatting of the f32 output */ + __m512 a01, b01; + __m512i c01 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(); + const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + + if (use_2d_blocking == 1) { + row_teams = handle->upd_row_teams; + column_teams = handle->upd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = (nBlocksIFm + row_teams-1)/row_teams; + in_tasks_per_thread = (nBlocksOFm + column_teams-1)/column_teams; + my_im_start = LIBXSMM_MIN( my_row_id * im_tasks_per_thread, nBlocksIFm); + my_im_end = LIBXSMM_MIN( (my_row_id+1) * im_tasks_per_thread, nBlocksIFm); + my_in_start = LIBXSMM_MIN( my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN( (my_col_id+1) * in_tasks_per_thread, nBlocksOFm); + } + + /* Required upfront tranposes */ + if (bc % 32 == 0) { + for (mb1ifm1 = tr_inp_thr_begin; mb1ifm1 < tr_inp_thr_end; mb1ifm1++) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + bf16_transpose((element_input_type*)&LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1, 0, 0, nBlocksIFm, bn, bc), &LIBXSMM_VLA_ACCESS(4, input_tr, ifm1, mb1, 0, 0, nBlocksMB, bc, bn), bc, bn, bc, bn); + } + } else { + for (mb1ifm1 = tr_inp_thr_begin; mb1ifm1 < tr_inp_thr_end; mb1ifm1++) { + mb1 = mb1ifm1%nBlocksMB; + ifm1 = mb1ifm1/nBlocksMB; + for (mb2 = 0; mb2 < bn; mb2++) { + for (ifm2 = 0; ifm2 < bc; ifm2++) { + LIBXSMM_VLA_ACCESS(4, input_tr, ifm1, mb1, ifm2, mb2, nBlocksMB, bc, bn) = LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc); + } + } + } + } + + if (performed_doutput_transpose == 0) { + if (bk % 32 == 0) { + for (mb1ofm1 = tr_out_thr_begin; mb1ofm1 < tr_out_thr_end; mb1ofm1++) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + bf16_vnni_reformat((element_output_type*)&LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, 0, 0, 0, nBlocksMB, bn_lp, bk, lpb), bk, bn, bk, bn); + } + } else { + for (mb1ofm1 = tr_out_thr_begin; mb1ofm1 < tr_out_thr_end; mb1ofm1++) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + for (mb2 = 0; mb2 < bn; mb2++) { + for (ofm2 = 0; ofm2 < bk; ofm2++) { + LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, mb1, mb2/lpb, ofm2, mb2%lpb, nBlocksMB, bn_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, bn, bk); + } + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if (use_2d_blocking == 1) { + ifm2 = 0; + ofm2 = 0; + if (BF == 1) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) { +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(5, dfilter, ofm1, ifm1, 0, 0, 0, nBlocksIFm, bc_lp, bk, lpb), bbc*bbk*sizeof(libxsmm_bfloat16)); +#endif + bf16_batchreduce_kernel_upd_zerobeta(&LIBXSMM_VLA_ACCESS(5, doutput_tr, ofm1, 0, 0, ofm2*bbk, 0, nBlocksMB, bn_lp, bk, lpb), &LIBXSMM_VLA_ACCESS(4, input_tr, ifm1, 0, ifm2*bbc, 0, nBlocksMB, bc, bn), &LIBXSMM_VLA_ACCESS(5, dfilter, ofm1, ifm1, 0, 0, 0, nBlocksIFm, bc_lp, bk, lpb), &blocks); + } + } + } else { + for (bfn = 0; bfn < BF; bfn++) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) { + /* initialize current work task to zero */ + if (bfn == 0) { + for (ii = 0; iibarrier, ltid); +} + +handle->tilerelease_kernel(NULL, NULL, NULL); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..69cedb30e02f8ee19bafbafe6a2204c78a92be07 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c @@ -0,0 +1,102 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int nBlocksIFm = handle->blocksifm; +const int nIFmBlock = handle->ifmblock; +const int nBlocksOFm = handle->blocksofm; +const int nOFmBlock = handle->ofmblock; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nBlocksOFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int ofm1 = 0; + +LIBXSMM_VLA_DECL(3, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksOFm, nOFmBlock); +#if defined(LIBXSMM_DNN_FULLYCONNECTED_FWD_BF16_F32) +float* input_f32_ptr = (float*)handle->scratch; +float* filter_f32_ptr = ((float*)handle->scratch)+((size_t)handle->desc.N*(size_t)handle->desc.C); +LIBXSMM_VLA_DECL(3, const float, input, input_f32_ptr, nBlocksIFm, nIFmBlock); +LIBXSMM_VLA_DECL(4, const float, filter, filter_f32_ptr, nBlocksIFm, nIFmBlock, nOFmBlock); + +/* number of tasks that could be run in parallel */ +const int work_input = handle->desc.N * handle->desc.C; +/* compute chunk size */ +const int chunksize_input = (work_input % handle->desc.threads == 0) ? (work_input / handle->desc.threads) : ((work_input / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin_input = (ltid * chunksize_input < work_input) ? (ltid * chunksize_input) : work_input; +const int thr_end_input = ((ltid + 1) * chunksize_input < work_input) ? ((ltid + 1) * chunksize_input) : work_input; + +/* number of tasks that could be run in parallel */ +const int work_filter = handle->desc.C * handle->desc.K; +/* compute chunk size */ +const int chunksize_filter = (work_filter % handle->desc.threads == 0) ? (work_filter / handle->desc.threads) : ((work_filter / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin_filter = (ltid * chunksize_filter < work_filter) ? (ltid * chunksize_filter) : work_filter; +const int thr_end_filter = ((ltid + 1) * chunksize_filter < work_filter) ? ((ltid + 1) * chunksize_filter) : work_filter; +#else +LIBXSMM_VLA_DECL(3, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, nIFmBlock); +LIBXSMM_VLA_DECL(4, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, nIFmBlock, nOFmBlock); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +#if defined(LIBXSMM_DNN_FULLYCONNECTED_FWD_BF16_F32) +libxsmm_convert_bf16_f32( ((element_input_type*)handle->reg_input->data)+thr_begin_input, input_f32_ptr+thr_begin_input, thr_end_input - thr_begin_input ); +libxsmm_convert_bf16_f32( ((element_filter_type*)handle->reg_filter->data)+thr_begin_filter, filter_f32_ptr+thr_begin_filter, thr_end_filter - thr_begin_filter ); + +libxsmm_barrier_wait(handle->barrier, ltid); +#endif + +for ( ofm1 = thr_begin; ofm1 < thr_end; ++ofm1 ) { /* outer GEMM m-loop */ +#if 1 + gemm_kernel( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, 0, 0, 0, nBlocksIFm, nIFmBlock, nOFmBlock), + &LIBXSMM_VLA_ACCESS(3, input, 0, 0, 0, nBlocksIFm, nIFmBlock), + &LIBXSMM_VLA_ACCESS(3, output, 0, ofm1, 0, nBlocksOFm, nOFmBlock) ); +#else + { + const int nImg = handle->desc.N; + int img2, ifm1, ifm2, ofm2; + + /* this is a simple replacement code using regular loops */ + for ( img2 = 0; img2 < nImg; ++img2 ) { + LIBXSMM_PRAGMA_SIMD + for ( ofm2 = 0; ofm2 < nOFmBlock; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(3, output, img2, ofm1, ofm2, nBlocksOFm, nOFmBlock) = (element_output_type)0; + } + } + for ( ifm1 = 0; ifm1 < nBlocksIFm; ++ifm1 ) { /* outer GEMM k-loop */ + for ( ifm2 = 0; ifm2 < nIFmBlock; ++ifm2 ) { /* GEMM K-loop */ + for ( img2 = 0; img2 < nImg; ++img2 ) { /* GEMM n-loop */ + LIBXSMM_PRAGMA_SIMD + for ( ofm2 = 0; ofm2 < nOFmBlock; ++ofm2 ) { /* GEMM m-loop */ + LIBXSMM_VLA_ACCESS(3, output, img2, ofm1, ofm2, nBlocksOFm, nOFmBlock) += + LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, nIFmBlock, nOFmBlock) * LIBXSMM_VLA_ACCESS(3, input, img2, ifm1, ifm2, nBlocksIFm, nIFmBlock); + } + } + } + } + } +#endif +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..e0f854b365a0ef44453b178206ab52fd1d7f27af --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c @@ -0,0 +1,235 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int nBlocksIFm = handle->desc.C / handle->bc; +const int nBlocksOFm = handle->desc.K / handle->bk; +const int nBlocksMB = handle->desc.N / handle->bn; +int use_2d_blocking = handle->fwd_2d_blocking; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nBlocksOFm * nBlocksMB; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, ifm1 = 0; +int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; +int mb2 = 0, ofm2 = 0; + +LIBXSMM_VLA_DECL(4, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksOFm, handle->bn, handle->bk); +LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, handle->bn, handle->bc); +LIBXSMM_VLA_DECL(4, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, handle->bc, handle->bk); +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS +LIBXSMM_VLA_DECL(2, const element_output_type, bias, (element_output_type*)handle->reg_bias->data, handle->bk); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU +LIBXSMM_VLA_DECL(4, unsigned char, relumask, (unsigned char*) handle->relumask->data, nBlocksOFm, handle->bn, handle->bk); +#endif +#endif + +unsigned long long blocks = nBlocksIFm; +int CB_BLOCKS = nBlocksIFm, BF = 1; + +BF = handle->fwd_bf; +CB_BLOCKS = nBlocksIFm/BF; +blocks = CB_BLOCKS; + +if (use_2d_blocking == 1) { + row_teams = handle->fwd_row_teams; + column_teams = handle->fwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams); + in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams); + my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if (use_2d_blocking == 1) { + if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk); + } + } +#else + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (element_output_type)0; + } + } +#endif + } + batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bc, handle->bk), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + if ( ifm1 == BF-1 ) { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + float l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > (element_output_type)0 ) ? 1 : 0); + l_cur_out = (l_cur_out > (element_output_type)0) ? l_cur_out : (element_output_type)0; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f; +#endif + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + } +#endif + } + } + } + } else { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk); + } + } + batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, 0, 0, 0, nBlocksIFm, handle->bc, handle->bk), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#else + batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, 0, 0, 0, nBlocksIFm, handle->bc, handle->bk), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + element_output_type l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > (element_output_type)0 ) ? 1 : 0); + l_cur_out = ( l_cur_out > (element_output_type)0 ) ? l_cur_out : (element_output_type)0; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f; +#endif + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } +#endif + } + } + } +} else { + if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk); + } + } +#else + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (element_output_type)0; + } + } +#endif + } + batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bc, handle->bk), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + if ( ifm1 == BF-1 ) { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + float l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > (element_output_type)0 ) ? 1 : 0); + l_cur_out = (l_cur_out > (element_output_type)0) ? l_cur_out : (element_output_type)0; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f; +#endif + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + } +#endif + } + } + } else { + for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk); + } + } + batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, 0, 0, 0, nBlocksIFm, handle->bc, handle->bk), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#else + batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter, ofm1, 0, 0, 0, nBlocksIFm, handle->bc, handle->bk), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + element_output_type l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > (element_output_type)0 ) ? 1 : 0); + l_cur_out = ( l_cur_out > (element_output_type)0 ) ? l_cur_out : (element_output_type)0; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f; +#endif + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } +#endif + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..bb3a22da77b86f0e9de9237867c97a9fb0232b66 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c @@ -0,0 +1,379 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int nBlocksIFm = handle->desc.C / handle->bc; +const int nBlocksOFm = handle->desc.K / handle->bk; +const int nBlocksMB = handle->desc.N / handle->bn; +int lpb = 2; +const int bc_lp = handle->bc/lpb; +/* const int bc = handle->bc;*/ +int use_2d_blocking = handle->fwd_2d_blocking; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nBlocksOFm * nBlocksMB; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, ifm1 = 0; +int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +int mb2 = 0, ofm2 = 0; +#endif +LIBXSMM_VLA_DECL(4, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksOFm, handle->bn, handle->bk); +LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, handle->bn, handle->bc); +LIBXSMM_VLA_DECL(5, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, handle->bk, lpb); +float* temp_output = (float*)handle->scratch; +LIBXSMM_VLA_DECL(4, float, output_f32, (float*) temp_output, nBlocksOFm,handle->bn,handle->bk); +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS +LIBXSMM_VLA_DECL(2, const element_input_type, bias, (element_input_type*) handle->reg_bias->data, handle->bk); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU +LIBXSMM_VLA_DECL(4, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk); +LIBXSMM_VLA_DECL(4, __mmask16, relubitmask, (__mmask16*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/16); +#endif +#endif +unsigned long long blocks = nBlocksIFm; +int CB_BLOCKS = nBlocksIFm, BF = 1; + +BF = handle->fwd_bf; +CB_BLOCKS = nBlocksIFm/BF; +blocks = CB_BLOCKS; + +if (use_2d_blocking == 1) { + row_teams = handle->fwd_row_teams; + column_teams = handle->fwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams); + in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams); + my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if (use_2d_blocking == 1) { + if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 bn; ++mb2 ) { + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm,handle->bn,handle->bk), handle->bk ); + } +#else + memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float)); +#endif + } + batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ifm1 == BF-1 ) { +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + if (handle->bk % 32 == 0) { + __m512 cur_out_0 = _mm512_setzero_ps(); + __m512 cur_out_1 = _mm512_setzero_ps(); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + __mmask16 relumask0; + __mmask16 relumask1; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + __m512 ones = _mm512_set1_ps(1.0); + __m512 halves = _mm512_set1_ps(0.5); +#endif + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) { + cur_out_0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk)); + cur_out_1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk)); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ ); + relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ ); + cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 ); + cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 ); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_0, halves)), ones), halves); + cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_1, halves)), ones), halves); +#endif + _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), cur_out_0); + _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk), cur_out_1); + } + } + } else { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + float l_cur_out = LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > (float)0 ) ? 1 : 0); + l_cur_out = (l_cur_out > (float)0) ? l_cur_out : (float)0; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f; +#endif + LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + } +#endif + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm,handle->bn,handle->bk), &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm,handle->bn,handle->bk),handle->bn*handle->bk); + } + } + } + } + } else { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk); + } + } + batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#else + batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + if (handle->bk % 32 == 0) { + __m512 cur_out_0 = _mm512_setzero_ps(); + __m512 cur_out_1 = _mm512_setzero_ps(); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + __mmask16 relumask0; + __mmask16 relumask1; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + __m512 ones = _mm512_set1_ps(1.0); + __m512 halves = _mm512_set1_ps(0.5); +#endif + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) { + cur_out_0 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk))); + cur_out_1 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk))); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ ); + relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ ); + cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 ); + cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 ); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_0, halves)), ones), halves); + cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_1, halves)), ones), halves); +#endif + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( cur_out_1, cur_out_0 )); + } + } + } else { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + libxsmm_bfloat16_hp t; +#endif + libxsmm_bfloat16 l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( (l_cur_out & 0x8000) > 0 ) ? 0 : 1); + l_cur_out = (libxsmm_bfloat16)(( (l_cur_out & 0x8000) > 0 ) ? 0 : l_cur_out); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + t.i[1] = l_cur_out; + t.i[0] = 0; + t.f = (libxsmm_stanh_pade78( t.f / 2.0f ) + 1.0f) / 2.0f; + l_cur_out = t.i[1]; +#endif + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + } +#endif + } + } + } +} else { + if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 bn; ++mb2 ) { + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm, handle->bn, handle->bk), handle->bk ); + } +#else + memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float)); +#endif + } + batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ifm1 == BF-1 ) { +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + if (handle->bk % 32 == 0) { + __m512 cur_out_0 = _mm512_setzero_ps(); + __m512 cur_out_1 = _mm512_setzero_ps(); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + __mmask16 relumask0; + __mmask16 relumask1; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + __m512 ones = _mm512_set1_ps(1.0); + __m512 halves = _mm512_set1_ps(0.5); +#endif + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) { + cur_out_0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk)); + cur_out_1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk)); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ ); + relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ ); + cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 ); + cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 ); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_0, halves)), ones), halves); + cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_1, halves)), ones), halves); +#endif + _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), cur_out_0); + _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk), cur_out_1); + } + } + } else { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { + float l_cur_out = LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > 0.0 ) ? 1 : 0); + l_cur_out = (l_cur_out > (float)0) ? l_cur_out : (float)0; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f; +#endif + LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + } +#endif + LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk); + } + } + } + } else { + for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 bk; ++ofm2 ) { + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk); + } + } + batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#else + batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE + if (handle->bk % 32 == 0) { + __m512 cur_out_0 = _mm512_setzero_ps(); + __m512 cur_out_1 = _mm512_setzero_ps(); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + __mmask16 relumask0; + __mmask16 relumask1; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + __m512 ones = _mm512_set1_ps(1.0); + __m512 halves = _mm512_set1_ps(0.5); +#endif + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) { + cur_out_0 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk))); + cur_out_1 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk))); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ ); + relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ ); + cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 ); + cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 ); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_0, halves)), ones), halves); + cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(_mm512_mul_ps(cur_out_1, halves)), ones), halves); +#endif + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( cur_out_1, cur_out_0 )); + } + } + } else { + for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) { + for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + libxsmm_bfloat16_hp t; +#endif + libxsmm_bfloat16 l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk); +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( (l_cur_out & 0x8000) > 0 ) ? 0 : 1); + l_cur_out = (libxsmm_bfloat16)(( (l_cur_out & 0x8000) > 0 ) ? 0 : l_cur_out); +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID + /* we ar using Pade 7/8 approximation */ + t.i[1] = l_cur_out; + t.i[0] = 0; + t.f = (libxsmm_stanh_pade78( t.f / 2.0f ) + 1.0f) / 2.0f; + l_cur_out = t.i[1]; +#endif + LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out; + } + } + } + +#endif + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..a8fb8f8ad3e3d751dc11823b6edce3f6be862558 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c @@ -0,0 +1,223 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int nBlocksIFm = handle->desc.C / handle->bc; +const int nBlocksOFm = handle->desc.K / handle->bk; +const int nBlocksMB = handle->desc.N / handle->bn; +const int bn = handle->bn; +const int bk = handle->bk; +const int lpb = 2; +const int bc_lp = handle->bc/lpb; +/* const int bc = handle->bc;*/ +int use_2d_blocking = handle->fwd_2d_blocking; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nBlocksOFm * nBlocksMB; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, ifm1 = 0; +int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; +LIBXSMM_VLA_DECL(4, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksOFm, handle->bn, handle->bk); +LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, handle->bn, handle->bc); +LIBXSMM_VLA_DECL(5, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, handle->bk, lpb); +float* temp_output = (float*)handle->scratch; +LIBXSMM_VLA_DECL(4, float, output_f32, (float*) temp_output, nBlocksOFm, bn, bk); + +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +libxsmm_meltw_gemm_param gemm_eltwise_params; +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_BIAS) +int mb2 = 0; +float* fp32_bias_scratch = (float*)handle->scratch + ltid * handle->desc.K; +LIBXSMM_VLA_DECL(2, const element_input_type, bias, (element_input_type*) handle->reg_bias->data, handle->bk); +#endif +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_RELU) +LIBXSMM_VLA_DECL(4, __mmask32, relubitmask, (__mmask32*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/32); +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_cvtfp32bf16_relu_kernel; +libxsmm_meltw_unary_param eltwise_params; +#elif defined(LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID) +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_sigmoid_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; +#else +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; +#endif +#else +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; +#endif + +unsigned long long blocks = nBlocksIFm; +int CB_BLOCKS = nBlocksIFm, BF = 1; + +BF = handle->fwd_bf; +CB_BLOCKS = nBlocksIFm/BF; +blocks = CB_BLOCKS; + +if (use_2d_blocking == 1) { + row_teams = handle->fwd_row_teams; + column_teams = handle->fwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = (nBlocksMB + row_teams-1)/row_teams; + in_tasks_per_thread = (nBlocksOFm + column_teams-1)/column_teams; + my_im_start = LIBXSMM_MIN( my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN( (my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN( my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN( (my_col_id+1) * in_tasks_per_thread, nBlocksOFm); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +tile_config_kernel(NULL, NULL, NULL); + +if (use_2d_blocking == 1) { + if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 bn; ++mb2 ) { + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm,handle->bn,handle->bk), handle->bk ); + } +#else + memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float)); +#endif + } + +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(float)); + if ( ifm1 == BF-1 ) { + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(libxsmm_bfloat16)); + } +#endif + batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ifm1 == BF-1 ) { + eltwise_params.in.primary = &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); + eltwise_params.out.primary = &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_RELU) + eltwise_params.out.secondary = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); +#endif + eltwise_kernel(&eltwise_params); + } + } + } + } + } else { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, 0, 0,handle->bk), fp32_bias_scratch, handle->desc.K ); +#endif + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(libxsmm_bfloat16)); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + gemm_eltwise_params.bias_ptr = (float*) fp32_bias_scratch + ofm1 * handle->bk; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + gemm_eltwise_params.out_ptr = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); +#endif + bf16_batchreduce_kernel_zerobeta_fused_eltwise( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks, &gemm_eltwise_params); +#else + bf16_batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks); +#endif + } + } + } +} else { + if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 bn; ++mb2 ) { + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm,handle->bn,handle->bk), handle->bk ); + } +#else + memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float)); +#endif + } +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(float)); + if ( ifm1 == BF-1 ) { + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(libxsmm_bfloat16)); + } +#endif + batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ifm1 == BF-1 ) { + eltwise_params.in.primary = &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); + eltwise_params.out.primary = &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_RELU) + eltwise_params.out.secondary = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); +#endif + eltwise_kernel(&eltwise_params); + } + } + } + } else { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, 0, 0,handle->bk), fp32_bias_scratch, handle->desc.K ); +#endif + for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) { + mb1 = mb1ofm1%nBlocksMB; + ofm1 = mb1ofm1/nBlocksMB; +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(libxsmm_bfloat16)); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + gemm_eltwise_params.bias_ptr = (float*) fp32_bias_scratch + ofm1 * handle->bk; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + gemm_eltwise_params.out_ptr = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); +#endif + bf16_batchreduce_kernel_zerobeta_fused_eltwise( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks, &gemm_eltwise_params); +#else + bf16_batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks); +#endif + } + } +} + +handle->tilerelease_kernel(NULL, NULL, NULL); +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..57f2712c9a6b29650b9b56f50b30d8e5833125d8 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c @@ -0,0 +1,177 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke (Intel Corp.) +******************************************************************************/ +/* size variables, all const */ +/* here we assume that input and output blocking is similar */ +const int nBlocksIFm = handle->desc.C / handle->bc; +const int nBlocksOFm = handle->desc.K / handle->bk; +const int nBlocksMB = handle->desc.N / handle->bn; +const int bn = handle->bn; +const int bk = handle->bk; +const int lpb = 2; +const int bc_lp = handle->bc/lpb; +/* const int bc = handle->bc;*/ +int use_2d_blocking = handle->fwd_2d_blocking; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* loop variables */ +int mb1 = 0, ofm1 = 0, ifm1 = 0; +int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0; +LIBXSMM_VLA_DECL(4, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksOFm, handle->bn, handle->bk); +LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, handle->bn, handle->bc); + +LIBXSMM_VLA_DECL(5, const element_filter_type, filter_compressed, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, handle->bk/handle->sparsity_factor_A, lpb); +LIBXSMM_VLA_DECL(5, __mmask32, idx_filter_compressed, (__mmask32*) ((element_filter_type*)handle->reg_filter->data + (handle->desc.C*handle->desc.K)/handle->sparsity_factor_A), nBlocksIFm, bc_lp, handle->bk/32, lpb); +LIBXSMM_VLA_DECL(4, element_filter_type, decompressed_filter, (element_filter_type*)handle->scratch + ltid * handle->bk * handle->desc.C, bc_lp, handle->bk, lpb); + +float* temp_output = (float*)handle->scratch + (handle->desc.threads * handle->desc.C * handle->bk)/2; +LIBXSMM_VLA_DECL(4, float, output_f32, (float*) temp_output, nBlocksOFm, bn, bk); +libxsmm_meltw_gemm_param gemm_eltwise_params; + +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_BIAS) +int mb2 = 0; +float* fp32_bias_scratch = (float*)handle->scratch + (handle->desc.threads * handle->desc.C * handle->bk)/2 + ltid * handle->desc.K; +LIBXSMM_VLA_DECL(2, const element_input_type, bias, (element_input_type*) handle->reg_bias->data, handle->bk); +#endif +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_RELU) +LIBXSMM_VLA_DECL(4, __mmask32, relubitmask, (__mmask32*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/32); +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_cvtfp32bf16_relu_kernel; +libxsmm_meltw_unary_param eltwise_params; +#elif defined(LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID) +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_sigmoid_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; +#else +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; +#endif +#else +libxsmm_meltwfunction_unary eltwise_kernel = handle->fwd_cvtfp32bf16_kernel; +libxsmm_meltw_unary_param eltwise_params; +#endif + +unsigned long long blocks = nBlocksIFm; +int CB_BLOCKS = nBlocksIFm, BF = 1; + +BF = handle->fwd_bf; +CB_BLOCKS = nBlocksIFm/BF; +blocks = CB_BLOCKS; + +if (use_2d_blocking == 1) { + row_teams = handle->fwd_row_teams; + column_teams = handle->fwd_column_teams; + my_col_id = ltid % column_teams; + my_row_id = ltid / column_teams; + im_tasks_per_thread = (nBlocksMB + row_teams-1)/row_teams; + in_tasks_per_thread = (nBlocksOFm + column_teams-1)/column_teams; + my_im_start = LIBXSMM_MIN( my_row_id * im_tasks_per_thread, nBlocksMB); + my_im_end = LIBXSMM_MIN( (my_row_id+1) * im_tasks_per_thread, nBlocksMB); + my_in_start = LIBXSMM_MIN( my_col_id * in_tasks_per_thread, nBlocksOFm); + my_in_end = LIBXSMM_MIN( (my_col_id+1) * in_tasks_per_thread, nBlocksOFm); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +tile_config_kernel(NULL, NULL, NULL); + +if (BF > 1) { + for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) { + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { + /* Initialize intermediate f32 tensor */ + if ( ifm1 == 0 ) { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + for ( mb2 = 0; mb2 bn; ++mb2 ) { + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm,handle->bn,handle->bk), handle->bk ); + } +#else + memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float)); +#endif + } + +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(float)); + if ( ifm1 == BF-1 ) { + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(libxsmm_bfloat16)); + } +#endif + if (mb1 == my_im_start) { + gemm_eltwise_params.sparse_bitmap = &LIBXSMM_VLA_ACCESS(5, idx_filter_compressed, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk/32, lpb); + gemm_eltwise_params.decompress_buffer = &LIBXSMM_VLA_ACCESS(4, decompressed_filter, 0, 0, 0, 0, bc_lp, handle->bk, lpb); + batchreduce_kernel_decompress( &LIBXSMM_VLA_ACCESS(5, filter_compressed, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk/handle->sparsity_factor_A, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks, &gemm_eltwise_params); + } else { + batchreduce_kernel( &LIBXSMM_VLA_ACCESS(4, decompressed_filter, 0, 0, 0, 0, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks); + } + + /* downconvert intermediate f32 tensor to bf 16 and store to final C */ + if ( ifm1 == BF-1 ) { + eltwise_params.in.primary = &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); + eltwise_params.out.primary = &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk); +#if defined(LIBXSMM_DNN_FC_FWD_FUSE_RELU) + eltwise_params.out.secondary = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); +#endif + eltwise_kernel(&eltwise_params); + } + } + } + } +} else { +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, 0, 0,handle->bk), fp32_bias_scratch, handle->desc.K ); +#endif + for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) { + for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) { +#ifdef WR_PREFETCH_OUTPUT + prefetchwt_chunk((char*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk*sizeof(libxsmm_bfloat16)); +#endif +#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS + gemm_eltwise_params.bias_ptr = (float*) fp32_bias_scratch + ofm1 * handle->bk; +#endif +#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU + gemm_eltwise_params.out_ptr = &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk/32); +#endif + if (mb1 == my_im_start) { + gemm_eltwise_params.sparse_bitmap = &LIBXSMM_VLA_ACCESS(5, idx_filter_compressed, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk/32, lpb); + gemm_eltwise_params.decompress_buffer = &LIBXSMM_VLA_ACCESS(4, decompressed_filter, 0, 0, 0, 0, bc_lp, handle->bk, lpb); + bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress( &LIBXSMM_VLA_ACCESS(5, filter_compressed, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk/handle->sparsity_factor_A, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks, &gemm_eltwise_params); + } else { + bf16_batchreduce_kernel_zerobeta_fused_eltwise( &LIBXSMM_VLA_ACCESS(4, decompressed_filter, 0, 0, 0, 0, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks, &gemm_eltwise_params); + } +#else + if (mb1 == my_im_start) { + gemm_eltwise_params.sparse_bitmap = &LIBXSMM_VLA_ACCESS(5, idx_filter_compressed, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk/32, lpb); + gemm_eltwise_params.decompress_buffer = &LIBXSMM_VLA_ACCESS(4, decompressed_filter, 0, 0, 0, 0, bc_lp, handle->bk, lpb); + bf16_batchreduce_kernel_zerobeta_decompress( &LIBXSMM_VLA_ACCESS(5, filter_compressed, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk/handle->sparsity_factor_A, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks, &gemm_eltwise_params); + } else { + bf16_batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(4, decompressed_filter, 0, 0, 0, 0, bc_lp, handle->bk, lpb), + &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc), + &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, bn, bk), &blocks); + } +#endif + } + } +} +handle->tilerelease_kernel(NULL, NULL, NULL); +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..d0acc711c140757be93843394866c1a1d5755aa7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c @@ -0,0 +1,251 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 16); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, 16); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 16); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 16); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)16), nImg, 16); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, const unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 2); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vbmean, lcl_vbrstd; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, 16); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, 16); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 2); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdeloutput = _mm512_load_act( del_output_ptr ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + const __m512 value = _mm512_load_act( output_ptr ); + const __mmask16 lcl_relumask = _mm512_cmp_ps_mask( value, _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + output_ptr += 16; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + const __mmask16 lcl_relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr, lcl_vdeloutput ); + del_input_add_ptr += sw*16; +#endif + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ), lcl_vdeloutput ), lcl_vbrstd ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, lcl_vdeloutput ); + + input_ptr += sw*16; + del_output_ptr += 16; + } + } + + _mm512_storeu_ps( del_gamma_img_ptr, lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr, lcl_vdbeta ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, 16); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, 16); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + del_gamma_img_ptr += 16; + del_beta_img_ptr += 16; + } + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 16), lcl_vdgamma ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 16), lcl_vdbeta ); + } + } else { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, 16); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, 16); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + del_gamma_img_ptr += 16; + del_beta_img_ptr += 16; + } + + _mm512_storeu_ps( del_gamma_img_ptr - (nImg*16), lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr - (nImg*16), lcl_vdbeta ); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbmean, lcl_vbrstd, lcl_vdgamma, lcl_vdbeta; + __m512 lcl_vnhw = _mm512_set1_ps( nhw ); + __m512 lcl_vrec_nhw = _mm512_set1_ps( recp_nhw ); + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 16) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16) ); + lcl_vdgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 16) ); + lcl_vdbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 16) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdelinput; + + lcl_vdelinput = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vdgamma ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vbrstd ); + lcl_vdelinput = _mm512_add_ps( lcl_vdbeta, lcl_vdelinput ); + lcl_vdelinput = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr ) ), lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vbrstd, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vgamma, lcl_vdelinput ); + _mm512_stream_act( del_input_ptr, lcl_vdelinput ); + + del_input_ptr += sw*16; + input_ptr += sw*16; + del_output_ptr += 16; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..dfc6c36886cda353f4fd50fdfb1380d8051a1a2e --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c @@ -0,0 +1,312 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 32); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, 32); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 32); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 32); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)32), nImg, 32); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, const unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 4); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + __m512 lcl_vbmean, lcl_vbrstd; + __m512 lcl_vbmean2, lcl_vbrstd2; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, 32); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, 32); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 4); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdeloutput, lcl_vdeloutput2; +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) || defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + const __m512 vzero = _mm512_setzero_ps(); + __mmask16 lcl_relumask, lcl_relumask2; +#endif + + lcl_vdeloutput = _mm512_load_act( del_output_ptr ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + lcl_relumask = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr ), vzero, _CMP_NEQ_OQ ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, vzero, lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, vzero, lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr, lcl_vdeloutput ); +#endif + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ), lcl_vdeloutput ), lcl_vbrstd ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, lcl_vdeloutput ); + + lcl_vdeloutput2 = _mm512_load_act( del_output_ptr+16 ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + lcl_relumask2 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+16 ), vzero, _CMP_NEQ_OQ ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, vzero, lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); + output_ptr += 32; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, vzero, lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+16, lcl_vdeloutput2 ); + del_input_add_ptr += sw*32; +#endif + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ), lcl_vdeloutput2 ), lcl_vbrstd2 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdeloutput2 ); + + input_ptr += sw*32; + del_output_ptr += 32; + } + } + + _mm512_storeu_ps( del_gamma_img_ptr, lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr, lcl_vdbeta ); + _mm512_storeu_ps( del_gamma_img_ptr+16, lcl_vdgamma2 ); + _mm512_storeu_ps( del_beta_img_ptr+16, lcl_vdbeta2 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, 32); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, 32); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_loadu_ps( del_gamma_img_ptr+16 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, _mm512_loadu_ps( del_beta_img_ptr+16 ) ); + del_gamma_img_ptr += 32; + del_beta_img_ptr += 32; + } + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 32), lcl_vdgamma ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 32), lcl_vdbeta ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 16, 32), lcl_vdgamma2 ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 16, 32), lcl_vdbeta2 ); + } + } else { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, 32); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, 32); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_loadu_ps( del_gamma_img_ptr+16 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, _mm512_loadu_ps( del_beta_img_ptr+16 ) ); + del_gamma_img_ptr += 32; + del_beta_img_ptr += 32; + } + + _mm512_storeu_ps( del_gamma_img_ptr - (32*nImg), lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr - (32*nImg), lcl_vdbeta ); + _mm512_storeu_ps( del_gamma_img_ptr - (32*nImg) + 16, lcl_vdgamma2 ); + _mm512_storeu_ps( del_beta_img_ptr - (32*nImg) + 16, lcl_vdbeta2 ); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbmean, lcl_vbrstd, lcl_vdgamma, lcl_vdbeta; + __m512 lcl_vgamma2, lcl_vbmean2, lcl_vbrstd2, lcl_vdgamma2, lcl_vdbeta2; + __m512 lcl_vnhw = _mm512_set1_ps( nhw ); + __m512 lcl_vrec_nhw = _mm512_set1_ps( recp_nhw ); + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 32) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32) ); + lcl_vdgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 32) ); + lcl_vdbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 32) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 32) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32) ); + lcl_vdgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 16, 32) ); + lcl_vdbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 16, 32) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdelinput; + __m512 lcl_vdelinput2; + + lcl_vdelinput = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vdgamma ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vbrstd ); + lcl_vdelinput = _mm512_add_ps( lcl_vdbeta, lcl_vdelinput ); + lcl_vdelinput = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr ) ), lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vbrstd, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vgamma, lcl_vdelinput ); + + lcl_vdelinput2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vdgamma2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vbrstd2 ); + lcl_vdelinput2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+16 ) ), lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vbrstd2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vgamma2, lcl_vdelinput2 ); + + _mm512_stream_act( del_input_ptr, lcl_vdelinput ); + _mm512_stream_act( del_input_ptr+16, lcl_vdelinput2 ); + + del_input_ptr += sw*32; + input_ptr += sw*32; + del_output_ptr += 32; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..3d09b9728e35adf78f2e3f3191f40df05cd683d6 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c @@ -0,0 +1,386 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm * 4; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 64); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, 64); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 64); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 64); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)64), nImg, 64); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, const unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 8); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + __m512 lcl_vdgamma3 = _mm512_setzero_ps(); + __m512 lcl_vdbeta3 = _mm512_setzero_ps(); + __m512 lcl_vdgamma4 = _mm512_setzero_ps(); + __m512 lcl_vdbeta4 = _mm512_setzero_ps(); + __m512 lcl_vbmean, lcl_vbrstd; + __m512 lcl_vbmean2, lcl_vbrstd2; + __m512 lcl_vbmean3, lcl_vbrstd3; + __m512 lcl_vbmean4, lcl_vbrstd4; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, 64); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, 64); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 64) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 64) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 64) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 64) ); + lcl_vbmean3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 32, 64) ); + lcl_vbrstd3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 32, 64) ); + lcl_vbmean4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 48, 64) ); + lcl_vbrstd4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 48, 64) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 8); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdeloutput, lcl_vdeloutput2, lcl_vdeloutput3, lcl_vdeloutput4; +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) || defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask, lcl_relumask2, lcl_relumask3, lcl_relumask4; + const __m512 vzero = _mm512_setzero_ps(); +#endif + + lcl_vdeloutput = _mm512_load_act( del_output_ptr ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + lcl_relumask = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr ), vzero, _CMP_NEQ_OQ ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, vzero, lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, vzero, lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr, lcl_vdeloutput ); +#endif + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ), lcl_vdeloutput ), lcl_vbrstd ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, lcl_vdeloutput ); + + lcl_vdeloutput2 = _mm512_load_act( del_output_ptr+16 ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + lcl_relumask2 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+16 ), vzero, _CMP_NEQ_OQ ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, vzero, lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, vzero, lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+16, lcl_vdeloutput2 ); +#endif + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ), lcl_vdeloutput2 ), lcl_vbrstd2 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdeloutput2 ); + + lcl_vdeloutput3 = _mm512_load_act( del_output_ptr+32 ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + lcl_relumask3 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+32 ), vzero, _CMP_NEQ_OQ ); + lcl_vdeloutput3 = _mm512_mask_blend_ps( lcl_relumask3, vzero, lcl_vdeloutput3 ); + _mm512_store_act( del_output_ptr+32, lcl_vdeloutput3 ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask3 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput3 = _mm512_mask_blend_ps( lcl_relumask3, vzero, lcl_vdeloutput3 ); + _mm512_store_act( del_output_ptr+32, lcl_vdeloutput3 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+32, lcl_vdeloutput3 ); +#endif + lcl_vdgamma3 = _mm512_add_ps( lcl_vdgamma3, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+32 ), lcl_vbmean3 ), lcl_vdeloutput3 ), lcl_vbrstd3 ) ); + lcl_vdbeta3 = _mm512_add_ps( lcl_vdbeta3, lcl_vdeloutput3 ); + + lcl_vdeloutput4 = _mm512_load_act( del_output_ptr+48 ); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + lcl_relumask4 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+48 ), vzero, _CMP_NEQ_OQ ); + lcl_vdeloutput4 = _mm512_mask_blend_ps( lcl_relumask4, vzero, lcl_vdeloutput4 ); + _mm512_store_act( del_output_ptr+48, lcl_vdeloutput4 ); + output_ptr += 64; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask4 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput4 = _mm512_mask_blend_ps( lcl_relumask4, vzero, lcl_vdeloutput4 ); + _mm512_store_act( del_output_ptr+48, lcl_vdeloutput4 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+48, lcl_vdeloutput4 ); + del_input_add_ptr += sw*64; +#endif + lcl_vdgamma4 = _mm512_add_ps( lcl_vdgamma4, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+48 ), lcl_vbmean4 ), lcl_vdeloutput4 ), lcl_vbrstd4 ) ); + lcl_vdbeta4 = _mm512_add_ps( lcl_vdbeta4, lcl_vdeloutput4 ); + + input_ptr += sw*64; + del_output_ptr += 64; + } + } + + _mm512_storeu_ps( del_gamma_img_ptr, lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr, lcl_vdbeta ); + _mm512_storeu_ps( del_gamma_img_ptr+16, lcl_vdgamma2 ); + _mm512_storeu_ps( del_beta_img_ptr+16, lcl_vdbeta2 ); + _mm512_storeu_ps( del_gamma_img_ptr+32, lcl_vdgamma3 ); + _mm512_storeu_ps( del_beta_img_ptr+32, lcl_vdbeta3 ); + _mm512_storeu_ps( del_gamma_img_ptr+48, lcl_vdgamma4 ); + _mm512_storeu_ps( del_beta_img_ptr+48, lcl_vdbeta4 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + del_gamma_img_ptr += 64; + del_beta_img_ptr += 64; + } + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, (fm/4), ((fm%4)*16), 64), lcl_vdgamma ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, (fm/4), ((fm%4)*16), 64), lcl_vdbeta ); + } + } else { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + del_gamma_img_ptr += 64; + del_beta_img_ptr += 64; + } + + _mm512_storeu_ps( del_gamma_img_ptr - (64*nImg), lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr - (64*nImg), lcl_vdbeta ); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbmean, lcl_vbrstd, lcl_vdgamma, lcl_vdbeta; + __m512 lcl_vgamma2, lcl_vbmean2, lcl_vbrstd2, lcl_vdgamma2, lcl_vdbeta2; + __m512 lcl_vgamma3, lcl_vbmean3, lcl_vbrstd3, lcl_vdgamma3, lcl_vdbeta3; + __m512 lcl_vgamma4, lcl_vbmean4, lcl_vbrstd4, lcl_vdgamma4, lcl_vdbeta4; + __m512 lcl_vnhw = _mm512_set1_ps( nhw ); + __m512 lcl_vrec_nhw = _mm512_set1_ps( recp_nhw ); + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 64) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 64) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 64) ); + lcl_vdgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 64) ); + lcl_vdbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 64) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 64) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 64) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 64) ); + lcl_vdgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 16, 64) ); + lcl_vdbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 16, 64) ); + + lcl_vgamma3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 32, 64) ); + lcl_vbmean3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 32, 64) ); + lcl_vbrstd3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 32, 64) ); + lcl_vdgamma3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 32, 64) ); + lcl_vdbeta3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 32, 64) ); + + lcl_vgamma4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 48, 64) ); + lcl_vbmean4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 48, 64) ); + lcl_vbrstd4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 48, 64) ); + lcl_vdgamma4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 48, 64) ); + lcl_vdbeta4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 48, 64) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdelinput; + __m512 lcl_vdelinput2; + __m512 lcl_vdelinput3; + __m512 lcl_vdelinput4; + + lcl_vdelinput = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vdgamma ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vbrstd ); + lcl_vdelinput = _mm512_add_ps( lcl_vdbeta, lcl_vdelinput ); + lcl_vdelinput = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr ) ), lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vbrstd, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vgamma, lcl_vdelinput ); + + lcl_vdelinput2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vdgamma2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vbrstd2 ); + lcl_vdelinput2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+16 ) ), lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vbrstd2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vgamma2, lcl_vdelinput2 ); + + lcl_vdelinput3 = _mm512_sub_ps( _mm512_load_act( input_ptr+32 ), lcl_vbmean3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vdelinput3, lcl_vdgamma3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vdelinput3, lcl_vbrstd3 ); + lcl_vdelinput3 = _mm512_add_ps( lcl_vdbeta3, lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+32 ) ), lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vbrstd3, lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vgamma3, lcl_vdelinput3 ); + + lcl_vdelinput4 = _mm512_sub_ps( _mm512_load_act( input_ptr+48 ), lcl_vbmean4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vdelinput4, lcl_vdgamma4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vdelinput4, lcl_vbrstd4 ); + lcl_vdelinput4 = _mm512_add_ps( lcl_vdbeta4, lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+48 ) ), lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vbrstd4, lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vgamma4, lcl_vdelinput4 ); + + _mm512_stream_act( del_input_ptr, lcl_vdelinput ); + _mm512_stream_act( del_input_ptr+16, lcl_vdelinput2 ); + _mm512_stream_act( del_input_ptr+32, lcl_vdelinput3 ); + _mm512_stream_act( del_input_ptr+48, lcl_vdelinput4 ); + + del_input_ptr += sw*64; + input_ptr += sw*64; + del_output_ptr += 64; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..e7b286c427a6a177d9820fba44b17277c15da48d --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c @@ -0,0 +1,274 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; +const int nFmBlock = handle->ifmblock; + +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int v = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, nFmBlock); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, nFmBlock); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) +union libxsmm_bfloat16_hp input_f32; +union libxsmm_bfloat16_hp del_input_f32; +union libxsmm_bfloat16_hp del_output_f32; +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) +union libxsmm_bfloat16_hp output_f32; +output_f32.i[1] = 0; +output_f32.i[0] = 0; +#endif +input_f32.i[1] = 0; +input_f32.i[0] = 0; +del_output_f32.i[1] = 0; +del_output_f32.i[0] = 0; +del_input_f32.i[1] = 0; +del_input_f32.i[0] = 0; +#endif + +assert( nFmBlock <= 64 ); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + /* @TODO check if we can bake this in into scratch */ + element_stats_type lcl_gamma_ptr[64]; + element_stats_type lcl_beta_ptr[64]; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, nFmBlock); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + lcl_gamma_ptr[v] = 0.0f; + lcl_beta_ptr[v] = 0.0f; + } + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); + const element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, nFmBlock); + const element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, nFmBlock); + +#if !defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for ( v=0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) + del_output_f32.i[1] = del_output_ptr[v]; + del_output_f32.i[0] = 0; +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + output_f32.i[1] = output_ptr[v]; + del_output_f32.f = LIBXSMM_FEQ(output_f32.f, 0) ? 0 : del_output_f32.f; + del_output_ptr[v] = del_output_f32.i[1]; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + del_output_ptr[v] = (element_output_type)(relumask_ptr[v] == 1 ? del_output_ptr[v] : 0); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + del_input_add_ptr[v] = del_output_ptr[v]; +#endif + input_f32.i[1] = input_ptr[v]; + lcl_gamma_ptr[v] += (input_f32.f - bmean_ptr[v]) * del_output_f32.f * brstd_ptr[v]; + lcl_beta_ptr[v] += del_output_f32.f; +#else +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU) + del_output_ptr[v] = LIBXSMM_FEQ(output_ptr[v], 0) ? 0 : del_output_ptr[v]; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK) + del_output_ptr[v] = (element_output_type)(relumask_ptr[v] == 1 ? del_output_ptr[v] : 0); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE) + del_input_add_ptr[v] = del_output_ptr[v]; +#endif + lcl_gamma_ptr[v] += (input_ptr[v] - bmean_ptr[v]) * del_output_ptr[v] * brstd_ptr[v]; + lcl_beta_ptr[v] += del_output_ptr[v]; +#endif + } + } + } + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_img_ptr[v] = lcl_gamma_ptr[v]; + del_beta_img_ptr[v] = lcl_beta_ptr[v]; + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, nFmBlock); + element_stats_type* del_beta_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_ptr[v] = (element_stats_type)0; + del_beta_ptr[v] = (element_stats_type)0; + } + + for ( img=0; img < nImg; img++ ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, nFmBlock); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_ptr[v] += del_gamma_img_ptr[v]; + del_beta_ptr[v] += del_beta_img_ptr[v]; + } + } + } + } else { + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, nFmBlock); + element_stats_type* del_beta_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, nFmBlock); + + for ( img=1; img < nImg; img++ ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, nFmBlock); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_ptr[v] += del_gamma_img_ptr[v]; + del_beta_ptr[v] += del_beta_img_ptr[v]; + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); + const element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, nFmBlock); + const element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, nFmBlock); + const element_stats_type* gamma_ptr = &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, nFmBlock); + const element_stats_type* del_gamma_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, nFmBlock); + const element_stats_type* del_beta_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, nFmBlock); + +#if !defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for ( v=0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_BWD_BF16) + del_output_f32.i[1] = del_output_ptr[v]; + input_f32.i[1] = input_ptr[v]; + del_input_f32.f = gamma_ptr[v] * brstd_ptr[v] * recp_nhw * (nhw*del_output_f32.f - + (del_beta_ptr[v] + (input_f32.f - bmean_ptr[v]) * del_gamma_ptr[v] * brstd_ptr[v])); + del_input_ptr[v] = del_input_f32.i[1]; +#else + del_input_ptr[v] = gamma_ptr[v] * brstd_ptr[v] * recp_nhw * (nhw*del_output_ptr[v] - + (del_beta_ptr[v] + (input_ptr[v] - bmean_ptr[v]) * del_gamma_ptr[v] * brstd_ptr[v])); +#endif + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..4d09a868fbc46308f87e8aa478b376042dae10cc --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c @@ -0,0 +1,248 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, 16); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * 16), nImg, 16); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 2); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, 16); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, 16); + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + lcl_vsum = _mm512_add_ps( lcl_vsum, lcl_vinput ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_mul_ps( lcl_vinput, lcl_vinput ) ); + + input_ptr += 16; + } + } + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, 0, 0, nImg, 16); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, 0, 0, nImg, 16); + + for ( img=0; img < nImg; img++ ) { + lcl_vsum = _mm512_add_ps( lcl_vsum, _mm512_loadu_ps( sum_img_ptr ) ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_loadu_ps( sumsq_img_ptr ) ); + sum_img_ptr += 16; + sumsq_img_ptr += 16; + } + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + __m512 lcl_vsqrt_eps = _mm512_set1_ps(sqrt_eps); + __m512 lcl_vrec_nhw = _mm512_set1_ps(recp_nhw); + __m512 lcl_vone = _mm512_set1_ps(1.0); + __m512 lcl_vbmean, lcl_vbmeansq, lcl_vsqbmean, lcl_vbrstd, lcl_vvar; + lcl_vbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum ); /* E(X) */ + lcl_vbmeansq = _mm512_mul_ps( lcl_vbmean, lcl_vbmean ); /* E(X)^2 */ + lcl_vsqbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq ); /* E(X^2) */ + lcl_vvar = _mm512_sub_ps( lcl_vsqbmean, lcl_vbmeansq ); /* variance */ +#if 0 + { + __m512d lcl_voned = _mm512_set1_pd(1.0); + __m512d lcl_vepsd = _mm512_set1_pd(1e-7); + __m512d lcl_vlo = _mm512_cvtps_pd( _mm256_castpd_ps( _mm512_extractf64x4_pd( _mm512_castps_pd( lcl_vvar ), 0 ) ) ); + __m512d lcl_vhi = _mm512_cvtps_pd( _mm256_castpd_ps( _mm512_extractf64x4_pd( _mm512_castps_pd( lcl_vvar ), 1 ) ) ); + lcl_vlo = _mm512_sqrt_pd( _mm512_add_pd( lcl_vlo, lcl_vepsd ) ); + lcl_vhi = _mm512_sqrt_pd( _mm512_add_pd( lcl_vhi, lcl_vepsd ) ); + lcl_vlo = _mm512_div_pd( lcl_voned, lcl_vlo ); + lcl_vhi = _mm512_div_pd( lcl_voned, lcl_vhi ); + lcl_vbrstd = _mm512_castpd_ps( _mm512_insertf64x4( _mm512_setzero_pd(), _mm256_castps_pd( _mm512_cvtpd_ps( lcl_vlo ) ), 0 ) ); + lcl_vbrstd = _mm512_castpd_ps( _mm512_insertf64x4( _mm512_castps_pd( lcl_vbrstd ), _mm256_castps_pd( _mm512_cvtpd_ps( lcl_vhi ) ), 1 ) ); + } +#else + lcl_vbrstd = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar, lcl_vsqrt_eps ) ) ); +#endif + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16), lcl_vbmean ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16), lcl_vbrstd ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, fm, 0, 16), lcl_vvar ); + } else { + sum_img_ptr -= 16*nImg; + sumsq_img_ptr -= 16*nImg; + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbeta, lcl_vbmean, lcl_vbrstd; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 16) ); + lcl_vbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, 16) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16) ); + + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#endif + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 2); +#endif + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + __m512 lcl_vo; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vo = _mm512_mul_ps( lcl_vgamma, lcl_vo ); + lcl_vo = _mm512_fmadd_ps( lcl_vo, lcl_vbrstd, lcl_vbeta ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo = _mm512_add_ps( lcl_vo, _mm512_load_act( input_add_ptr ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo = _mm512_max_ps( lcl_vo, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = _mm512_cmp_ps_mask( lcl_vo, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vo ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask ); + relumask_ptr += 2; +#endif + _mm512_stream_act( output_ptr, lcl_vo ); + + input_ptr += sw*16; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + input_add_ptr += sw*16; +#endif + output_ptr += 16; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..fac158a78e0823312e077f9ae98e988e430ff451 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c @@ -0,0 +1,294 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, 32); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * 32), nImg, 32); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 4); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + __m512 lcl_vsum2 = _mm512_setzero_ps(); + __m512 lcl_vsumsq2 = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, 32); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, 32); + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __m512 lcl_vinput2 = _mm512_load_act( input_ptr+16 ); + + lcl_vsum = _mm512_add_ps( lcl_vsum, lcl_vinput ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_mul_ps( lcl_vinput, lcl_vinput ) ); + + lcl_vsum2 = _mm512_add_ps( lcl_vsum2, lcl_vinput2 ); + lcl_vsumsq2 = _mm512_add_ps( lcl_vsumsq2, _mm512_mul_ps( lcl_vinput2, lcl_vinput2 ) ); + + input_ptr += 32; + } + } + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + + _mm512_storeu_ps( sum_img_ptr+16, lcl_vsum2 ); + _mm512_storeu_ps( sumsq_img_ptr+16, lcl_vsumsq2 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + __m512 lcl_vsum2 = _mm512_setzero_ps(); + __m512 lcl_vsumsq2 = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, 0, 0, nImg, 32); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, 0, 0, nImg, 32); + + for ( img=0; img < nImg; img++ ) { + lcl_vsum = _mm512_add_ps( lcl_vsum, _mm512_loadu_ps( sum_img_ptr ) ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_loadu_ps( sumsq_img_ptr ) ); + + lcl_vsum2 = _mm512_add_ps( lcl_vsum2, _mm512_loadu_ps( sum_img_ptr+16 ) ); + lcl_vsumsq2 = _mm512_add_ps( lcl_vsumsq2, _mm512_loadu_ps( sumsq_img_ptr+16 ) ); + + sum_img_ptr += 32; + sumsq_img_ptr += 32; + } + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + __m512 lcl_vsqrt_eps = _mm512_set1_ps(sqrt_eps); + __m512 lcl_vrec_nhw = _mm512_set1_ps(recp_nhw); + __m512 lcl_vone = _mm512_set1_ps(1.0); + __m512 lcl_vbmean, lcl_vbmeansq, lcl_vsqbmean, lcl_vbrstd, lcl_vvar; + __m512 lcl_vbmean2, lcl_vbmeansq2, lcl_vsqbmean2, lcl_vbrstd2, lcl_vvar2; + + lcl_vbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum ); /* E(X) */ + lcl_vbmeansq = _mm512_mul_ps( lcl_vbmean, lcl_vbmean ); /* E(X)^2 */ + lcl_vsqbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq ); /* E(X^2) */ + lcl_vvar = _mm512_sub_ps( lcl_vsqbmean, lcl_vbmeansq ); /* variance */ + lcl_vbrstd = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar, lcl_vsqrt_eps ) ) ); + + lcl_vbmean2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum2 ); /* E(X) */ + lcl_vbmeansq2 = _mm512_mul_ps( lcl_vbmean2, lcl_vbmean2 ); /* E(X)^2 */ + lcl_vsqbmean2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq2 ); /* E(X^2) */ + lcl_vvar2 = _mm512_sub_ps( lcl_vsqbmean2, lcl_vbmeansq2 ); /* variance */ + lcl_vbrstd2 = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar2, lcl_vsqrt_eps ) ) ); + + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32), lcl_vbmean ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32), lcl_vbrstd ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, fm, 0, 32), lcl_vvar ); + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32), lcl_vbmean2 ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32), lcl_vbrstd2 ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, fm, 16, 32), lcl_vvar2 ); + } else { + sum_img_ptr -= 32*nImg; + sumsq_img_ptr -= 32*nImg; + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + + _mm512_storeu_ps( sum_img_ptr+16, lcl_vsum2 ); + _mm512_storeu_ps( sumsq_img_ptr+16, lcl_vsumsq2 ); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbeta, lcl_vbmean, lcl_vbrstd; + __m512 lcl_vgamma2, lcl_vbeta2, lcl_vbmean2, lcl_vbrstd2; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 32) ); + lcl_vbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, 32) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 32) ); + lcl_vbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 16, 32) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32) ); + + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#endif + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 4); +#endif + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + __m512 lcl_vo; + __m512 lcl_vo2; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask; + __mmask16 lcl_relumask2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vo = _mm512_mul_ps( lcl_vgamma, lcl_vo ); + lcl_vo = _mm512_fmadd_ps( lcl_vo, lcl_vbrstd, lcl_vbeta ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo = _mm512_add_ps( lcl_vo, _mm512_load_act( input_add_ptr ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo = _mm512_max_ps( lcl_vo, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = _mm512_cmp_ps_mask( lcl_vo, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vo ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vo2 = _mm512_mul_ps( lcl_vgamma2, lcl_vo2 ); + lcl_vo2 = _mm512_fmadd_ps( lcl_vo2, lcl_vbrstd2, lcl_vbeta2 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo2 = _mm512_add_ps( lcl_vo2, _mm512_load_act( input_add_ptr+16 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo2 = _mm512_max_ps( lcl_vo2, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = _mm512_cmp_ps_mask( lcl_vo2, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vo2 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask2 ); + relumask_ptr += 2; +#endif + + _mm512_stream_act( output_ptr, lcl_vo ); + _mm512_stream_act( output_ptr+16, lcl_vo2 ); + + input_ptr += sw*32; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + input_add_ptr += sw*32; +#endif + output_ptr += 32; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..2aacc9a7ca6d5064b9e75ca065a2657cabbdbb44 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c @@ -0,0 +1,348 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm * 4; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, 64); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * 64), nImg, 64); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 8); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + __m512 lcl_vsum2 = _mm512_setzero_ps(); + __m512 lcl_vsumsq2 = _mm512_setzero_ps(); + __m512 lcl_vsum3 = _mm512_setzero_ps(); + __m512 lcl_vsumsq3 = _mm512_setzero_ps(); + __m512 lcl_vsum4 = _mm512_setzero_ps(); + __m512 lcl_vsumsq4 = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, 64); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, 64); + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __m512 lcl_vinput2 = _mm512_load_act( input_ptr+16 ); + __m512 lcl_vinput3 = _mm512_load_act( input_ptr+32 ); + __m512 lcl_vinput4 = _mm512_load_act( input_ptr+48 ); + + lcl_vsum = _mm512_add_ps( lcl_vsum, lcl_vinput ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_mul_ps( lcl_vinput, lcl_vinput ) ); + + lcl_vsum2 = _mm512_add_ps( lcl_vsum2, lcl_vinput2 ); + lcl_vsumsq2 = _mm512_add_ps( lcl_vsumsq2, _mm512_mul_ps( lcl_vinput2, lcl_vinput2 ) ); + + lcl_vsum3 = _mm512_add_ps( lcl_vsum3, lcl_vinput3 ); + lcl_vsumsq3 = _mm512_add_ps( lcl_vsumsq3, _mm512_mul_ps( lcl_vinput3, lcl_vinput3 ) ); + + lcl_vsum4 = _mm512_add_ps( lcl_vsum4, lcl_vinput4 ); + lcl_vsumsq4 = _mm512_add_ps( lcl_vsumsq4, _mm512_mul_ps( lcl_vinput4, lcl_vinput4 ) ); + + input_ptr += 64; + } + } + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + + _mm512_storeu_ps( sum_img_ptr+16, lcl_vsum2 ); + _mm512_storeu_ps( sumsq_img_ptr+16, lcl_vsumsq2 ); + + _mm512_storeu_ps( sum_img_ptr+32, lcl_vsum3 ); + _mm512_storeu_ps( sumsq_img_ptr+32, lcl_vsumsq3 ); + + _mm512_storeu_ps( sum_img_ptr+48, lcl_vsum4 ); + _mm512_storeu_ps( sumsq_img_ptr+48, lcl_vsumsq4 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + + for ( img=0; img < nImg; img++ ) { + lcl_vsum = _mm512_add_ps( lcl_vsum, _mm512_loadu_ps( sum_img_ptr ) ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_loadu_ps( sumsq_img_ptr ) ); + + sum_img_ptr += 64; + sumsq_img_ptr += 64; + } + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + __m512 lcl_vsqrt_eps = _mm512_set1_ps(sqrt_eps); + __m512 lcl_vrec_nhw = _mm512_set1_ps(recp_nhw); + __m512 lcl_vone = _mm512_set1_ps(1.0); + __m512 lcl_vbmean, lcl_vbmeansq, lcl_vsqbmean, lcl_vbrstd, lcl_vvar; + + lcl_vbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum ); /* E(X) */ + lcl_vbmeansq = _mm512_mul_ps( lcl_vbmean, lcl_vbmean ); /* E(X)^2 */ + lcl_vsqbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq ); /* E(X^2) */ + lcl_vvar = _mm512_sub_ps( lcl_vsqbmean, lcl_vbmeansq ); /* variance */ + lcl_vbrstd = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar, lcl_vsqrt_eps ) ) ); + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, (fm/4), ((fm%4)*16), 64), lcl_vbmean ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, (fm/4), ((fm%4)*16), 64), lcl_vbrstd ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, (fm/4), ((fm%4)*16), 64), lcl_vvar ); + } else { + sum_img_ptr -= 64*nImg; + sumsq_img_ptr -= 64*nImg; + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbeta, lcl_vbmean, lcl_vbrstd; + __m512 lcl_vgamma2, lcl_vbeta2, lcl_vbmean2, lcl_vbrstd2; + __m512 lcl_vgamma3, lcl_vbeta3, lcl_vbmean3, lcl_vbrstd3; + __m512 lcl_vgamma4, lcl_vbeta4, lcl_vbmean4, lcl_vbrstd4; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 64) ); + lcl_vbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, 64) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 64) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 64) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 64) ); + lcl_vbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 16, 64) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 64) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 64) ); + + lcl_vgamma3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 32, 64) ); + lcl_vbeta3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 32, 64) ); + lcl_vbmean3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 32, 64) ); + lcl_vbrstd3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 32, 64) ); + + lcl_vgamma4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 48, 64) ); + lcl_vbeta4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 48, 64) ); + lcl_vbmean4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 48, 64) ); + lcl_vbrstd4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 48, 64) ); + + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#endif + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 8); +#endif + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + __m512 lcl_vo; + __m512 lcl_vo2; + __m512 lcl_vo3; + __m512 lcl_vo4; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask; + __mmask16 lcl_relumask2; + __mmask16 lcl_relumask3; + __mmask16 lcl_relumask4; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vo = _mm512_mul_ps( lcl_vgamma, lcl_vo ); + lcl_vo = _mm512_fmadd_ps( lcl_vo, lcl_vbrstd, lcl_vbeta ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo = _mm512_add_ps( lcl_vo, _mm512_load_act( input_add_ptr ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo = _mm512_max_ps( lcl_vo, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = _mm512_cmp_ps_mask( lcl_vo, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vo ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vo2 = _mm512_mul_ps( lcl_vgamma2, lcl_vo2 ); + lcl_vo2 = _mm512_fmadd_ps( lcl_vo2, lcl_vbrstd2, lcl_vbeta2 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo2 = _mm512_add_ps( lcl_vo2, _mm512_load_act( input_add_ptr+16 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo2 = _mm512_max_ps( lcl_vo2, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = _mm512_cmp_ps_mask( lcl_vo2, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vo2 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask2 ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo3 = _mm512_sub_ps( _mm512_load_act( input_ptr+32 ), lcl_vbmean3 ); + lcl_vo3 = _mm512_mul_ps( lcl_vgamma3, lcl_vo3 ); + lcl_vo3 = _mm512_fmadd_ps( lcl_vo3, lcl_vbrstd3, lcl_vbeta3 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo3 = _mm512_add_ps( lcl_vo3, _mm512_load_act( input_add_ptr+32 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo3 = _mm512_max_ps( lcl_vo3, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask3 = _mm512_cmp_ps_mask( lcl_vo3, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo3 = _mm512_mask_blend_ps( lcl_relumask3, _mm512_setzero_ps(), lcl_vo3 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask3 ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo4 = _mm512_sub_ps( _mm512_load_act( input_ptr+48 ), lcl_vbmean4 ); + lcl_vo4 = _mm512_mul_ps( lcl_vgamma4, lcl_vo4 ); + lcl_vo4 = _mm512_fmadd_ps( lcl_vo4, lcl_vbrstd4, lcl_vbeta4 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + lcl_vo4 = _mm512_add_ps( lcl_vo4, _mm512_load_act( input_add_ptr+48 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + lcl_vo4 = _mm512_max_ps( lcl_vo4, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask4 = _mm512_cmp_ps_mask( lcl_vo4, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo4 = _mm512_mask_blend_ps( lcl_relumask4, _mm512_setzero_ps(), lcl_vo4 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask4 ); + relumask_ptr += 2; +#endif + + _mm512_stream_act( output_ptr, lcl_vo ); + _mm512_stream_act( output_ptr+16, lcl_vo2 ); + _mm512_stream_act( output_ptr+32, lcl_vo3 ); + _mm512_stream_act( output_ptr+48, lcl_vo4 ); + + input_ptr += sw*64; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + input_add_ptr += sw*64; +#endif + output_ptr += 64; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..76e512f8e30c28e3db14a6122c42e98a427dd3f7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c @@ -0,0 +1,265 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +const int nImg = handle->desc.partN; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; +const int nFmBlock = handle->ifmblock; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.fullN * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int v = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, nFmBlock); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, nFmBlock); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) +union libxsmm_bfloat16_hp input_f32; +union libxsmm_bfloat16_hp output_f32; +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) +union libxsmm_bfloat16_hp input_add_f32; +input_add_f32.i[1] = 0; +input_add_f32.i[0] = 0; +#endif +input_f32.i[1] = 0; +input_f32.i[0] = 0; +output_f32.i[1] = 0; +output_f32.i[0] = 0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) > 0) ) { + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + /* @TODO check if we can bake this in into scratch */ + element_stats_type lcl_sum_ptr[64]; + element_stats_type lcl_sumsq_ptr[64]; + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, nFmBlock); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + lcl_sum_ptr[v] = (element_stats_type)0; + lcl_sumsq_ptr[v] = (element_stats_type)0; + } + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + +#if !defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for (v=0; v < nFmBlock; v++) { +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) + input_f32.i[1] = input_ptr[v]; + lcl_sum_ptr[v] += input_f32.f; + lcl_sumsq_ptr[v] += (input_f32.f * input_f32.f); +#else + lcl_sum_ptr[v] += input_ptr[v]; + lcl_sumsq_ptr[v] += (input_ptr[v] * input_ptr[v]); +#endif + } + } + } + + LIBXSMM_PRAGMA_SIMD + for (v=0; v < nFmBlock; v++) { + sum_img_ptr[v] = lcl_sum_ptr[v]; + sumsq_img_ptr[v] = lcl_sumsq_ptr[v]; + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + /* @TODO check if we can bake this in into scratch */ + element_stats_type lcl_sum_ptr[64]; + element_stats_type lcl_sumsq_ptr[64]; + element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, nFmBlock); + element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, nFmBlock); + element_stats_type* tvar_ptr = &LIBXSMM_VLA_ACCESS(2, variance, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + lcl_sum_ptr[v] = (element_stats_type)0; + lcl_sumsq_ptr[v] = (element_stats_type)0; + } + + for ( img=0; img < nImg; img++ ) { + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, nFmBlock); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + lcl_sum_ptr[v] += sum_img_ptr[v]; + lcl_sumsq_ptr[v] += sumsq_img_ptr[v]; + } + } + + if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) > 0) ) { + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + const element_stats_type tbmean = (recp_nhw * lcl_sum_ptr[v]); + const element_stats_type tbmeansq = tbmean * tbmean; + const element_stats_type tsqbmean = recp_nhw * lcl_sumsq_ptr[v]; + const element_stats_type tvar = tsqbmean - tbmeansq; + const element_stats_type tbrstd = (element_stats_type)(1.0/sqrt((double)tvar + sqrt_eps)); + bmean_ptr[v] = tbmean; + brstd_ptr[v] = tbrstd; + tvar_ptr[v] = tvar; + } + } else { + element_stats_type* sum_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, 0, 0, nImg, nFmBlock); + element_stats_type* sumsq_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, 0, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + sum_ptr[v] = lcl_sum_ptr[v]; + sumsq_ptr[v] = lcl_sumsq_ptr[v]; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0) || + ((handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) > 0) ) { + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif + const element_stats_type* gamma_ptr = &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, nFmBlock); + const element_stats_type* beta_ptr = &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, nFmBlock); + const element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, nFmBlock); + const element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, nFmBlock); + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + float o; + +#if !defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for (v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) + input_f32.i[1] = input_ptr[v]; + o = gamma_ptr[v]*(input_f32.f - bmean_ptr[v])*brstd_ptr[v] + beta_ptr[v]; +#else + /* BN + scale (gamma, beta) */ + o = gamma_ptr[v]*(input_ptr[v] - bmean_ptr[v])*brstd_ptr[v] + beta_ptr[v]; +#endif + /* Eltwise */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE) +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) + input_add_f32.i[1] = input_add_ptr[v]; + o += input_add_f32.f; +#else + o += input_add_ptr[v]; +#endif +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU) + o = ( o > 0.0f ) ? o : 0.0f; +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK) + o = ( o > 0.0f ) ? o : 0.0f; + relumask_ptr[v] = (unsigned char)(o > 0.0f ? 1 : 0); +#endif +#if defined(LIBXSMM_DNN_FUSEDBN_FWD_BF16) + output_f32.f = o; + output_ptr[v] = output_f32.i[1]; +#else + output_ptr[v] = o; +#endif + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..cb10fbd81c687e272a91b3042525e479597ad6fd --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c @@ -0,0 +1,222 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +const element_stats_type nhw = (element_stats_type)(handle->desc.N * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 16); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, 16); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 16); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 16); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)16), nImg, 16); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, const unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 2); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vbmean, lcl_vbrstd; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, 16); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, 16); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 2); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdeloutput = _mm512_load_act( del_output_ptr ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + const __mmask16 lcl_relumask = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + output_ptr += 16; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + const __mmask16 lcl_relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr, lcl_vdeloutput ); + del_input_add_ptr += sw*16; +#endif + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ), lcl_vdeloutput ), lcl_vbrstd ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, lcl_vdeloutput ); + + input_ptr += sw*16; + del_output_ptr += 16; + } + } + + _mm512_storeu_ps( del_gamma_img_ptr, lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr, lcl_vdbeta ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, 16); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, 16); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + del_gamma_img_ptr += 16; + del_beta_img_ptr += 16; + } + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 16), lcl_vdgamma ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 16), lcl_vdbeta ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbmean, lcl_vbrstd, lcl_vdgamma, lcl_vdbeta; + __m512 lcl_vnhw = _mm512_set1_ps( nhw ); + __m512 lcl_vrec_nhw = _mm512_set1_ps( recp_nhw ); + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 16) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16) ); + lcl_vdgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 16) ); + lcl_vdbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 16) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdelinput; + + lcl_vdelinput = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vdgamma ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vbrstd ); + lcl_vdelinput = _mm512_add_ps( lcl_vdbeta, lcl_vdelinput ); + lcl_vdelinput = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr ) ), lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vbrstd, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vgamma, lcl_vdelinput ); + _mm512_stream_act( del_input_ptr, lcl_vdelinput ); + + del_input_ptr += sw*16; + input_ptr += sw*16; + del_output_ptr += 16; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..11d7dd00fc55e969458e0bfb424232cc1b1dc5be --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c @@ -0,0 +1,280 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +const element_stats_type nhw = (element_stats_type)(handle->desc.N * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 32); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, 32); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 32); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 32); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)32), nImg, 32); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, const unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 4); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + __m512 lcl_vbmean, lcl_vbrstd; + __m512 lcl_vbmean2, lcl_vbrstd2; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, 32); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, 32); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 4); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdeloutput, lcl_vdeloutput2; +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + __mmask16 lcl_relumask, lcl_relumask2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask, lcl_relumask2; +#endif + + lcl_vdeloutput = _mm512_load_act( del_output_ptr ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + lcl_relumask = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr, lcl_vdeloutput ); +#endif + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ), lcl_vdeloutput ), lcl_vbrstd ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, lcl_vdeloutput ); + + lcl_vdeloutput2 = _mm512_load_act( del_output_ptr+16 ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + lcl_relumask2 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+16 ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); + output_ptr += 32; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+16, lcl_vdeloutput2 ); + del_input_add_ptr += sw*32; +#endif + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ), lcl_vdeloutput2 ), lcl_vbrstd2 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdeloutput2 ); + + input_ptr += sw*32; + del_output_ptr += 32; + } + } + + _mm512_storeu_ps( del_gamma_img_ptr, lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr, lcl_vdbeta ); + _mm512_storeu_ps( del_gamma_img_ptr+16, lcl_vdgamma2 ); + _mm512_storeu_ps( del_beta_img_ptr+16, lcl_vdbeta2 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, 0, 0, nImg, 32); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, 0, 0, nImg, 32); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_loadu_ps( del_gamma_img_ptr+16 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, _mm512_loadu_ps( del_beta_img_ptr+16 ) ); + del_gamma_img_ptr += 32; + del_beta_img_ptr += 32; + } + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 32), lcl_vdgamma ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 32), lcl_vdbeta ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 16, 32), lcl_vdgamma2 ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 16, 32), lcl_vdbeta2 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbmean, lcl_vbrstd, lcl_vdgamma, lcl_vdbeta; + __m512 lcl_vgamma2, lcl_vbmean2, lcl_vbrstd2, lcl_vdgamma2, lcl_vdbeta2; + __m512 lcl_vnhw = _mm512_set1_ps( nhw ); + __m512 lcl_vrec_nhw = _mm512_set1_ps( recp_nhw ); + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 32) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32) ); + lcl_vdgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 32) ); + lcl_vdbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 32) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 32) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32) ); + lcl_vdgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 16, 32) ); + lcl_vdbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 16, 32) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdelinput; + __m512 lcl_vdelinput2; + + lcl_vdelinput = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vdgamma ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vbrstd ); + lcl_vdelinput = _mm512_add_ps( lcl_vdbeta, lcl_vdelinput ); + lcl_vdelinput = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr ) ), lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vbrstd, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vgamma, lcl_vdelinput ); + + lcl_vdelinput2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vdgamma2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vbrstd2 ); + lcl_vdelinput2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+16 ) ), lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vbrstd2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vgamma2, lcl_vdelinput2 ); + + _mm512_stream_act( del_input_ptr, lcl_vdelinput ); + _mm512_stream_act( del_input_ptr+16, lcl_vdelinput2 ); + + del_input_ptr += sw*32; + input_ptr += sw*32; + del_output_ptr += 32; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..b3c582319358bccc1542d140352707a039e7ed86 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c @@ -0,0 +1,360 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +const element_stats_type nhw = (element_stats_type)(handle->desc.N * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm * 4; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 64); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, 64); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 64); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 64); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)64), nImg, 64); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, const unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 8); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + __m512 lcl_vdgamma2 = _mm512_setzero_ps(); + __m512 lcl_vdbeta2 = _mm512_setzero_ps(); + __m512 lcl_vdgamma3 = _mm512_setzero_ps(); + __m512 lcl_vdbeta3 = _mm512_setzero_ps(); + __m512 lcl_vdgamma4 = _mm512_setzero_ps(); + __m512 lcl_vdbeta4 = _mm512_setzero_ps(); + __m512 lcl_vbmean, lcl_vbrstd; + __m512 lcl_vbmean2, lcl_vbrstd2; + __m512 lcl_vbmean3, lcl_vbrstd3; + __m512 lcl_vbmean4, lcl_vbrstd4; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, 64); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, 64); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 64) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 64) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 64) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 64) ); + lcl_vbmean3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 32, 64) ); + lcl_vbrstd3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 32, 64) ); + lcl_vbmean4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 48, 64) ); + lcl_vbrstd4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 48, 64) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 8); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdeloutput, lcl_vdeloutput2, lcl_vdeloutput3, lcl_vdeloutput4; +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + __mmask16 lcl_relumask, lcl_relumask2, lcl_relumask3, lcl_relumask4; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask, lcl_relumask2, lcl_relumask3, lcl_relumask4; +#endif + + lcl_vdeloutput = _mm512_load_act( del_output_ptr ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + lcl_relumask = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vdeloutput ); + _mm512_store_act( del_output_ptr, lcl_vdeloutput ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr, lcl_vdeloutput ); +#endif + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ), lcl_vdeloutput ), lcl_vbrstd ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, lcl_vdeloutput ); + + lcl_vdeloutput2 = _mm512_load_act( del_output_ptr+16 ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + lcl_relumask2 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+16 ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vdeloutput2 ); + _mm512_store_act( del_output_ptr+16, lcl_vdeloutput2 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+16, lcl_vdeloutput2 ); +#endif + lcl_vdgamma2 = _mm512_add_ps( lcl_vdgamma2, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ), lcl_vdeloutput2 ), lcl_vbrstd2 ) ); + lcl_vdbeta2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdeloutput2 ); + + lcl_vdeloutput3 = _mm512_load_act( del_output_ptr+32 ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + lcl_relumask3 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+32 ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput3 = _mm512_mask_blend_ps( lcl_relumask3, _mm512_setzero_ps(), lcl_vdeloutput3 ); + _mm512_store_act( del_output_ptr+32, lcl_vdeloutput3 ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask3 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput3 = _mm512_mask_blend_ps( lcl_relumask3, _mm512_setzero_ps(), lcl_vdeloutput3 ); + _mm512_store_act( del_output_ptr+32, lcl_vdeloutput3 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+32, lcl_vdeloutput3 ); +#endif + lcl_vdgamma3 = _mm512_add_ps( lcl_vdgamma3, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+32 ), lcl_vbmean3 ), lcl_vdeloutput3 ), lcl_vbrstd3 ) ); + lcl_vdbeta3 = _mm512_add_ps( lcl_vdbeta3, lcl_vdeloutput3 ); + + lcl_vdeloutput4 = _mm512_load_act( del_output_ptr+48 ); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + lcl_relumask4 = _mm512_cmp_ps_mask( _mm512_load_act( output_ptr+48 ), _mm512_setzero_ps(), _CMP_NEQ_OQ ); + lcl_vdeloutput4 = _mm512_mask_blend_ps( lcl_relumask4, _mm512_setzero_ps(), lcl_vdeloutput4 ); + _mm512_store_act( del_output_ptr+48, lcl_vdeloutput4 ); + output_ptr += 64; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + lcl_relumask4 = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16( relumask_ptr ); + lcl_vdeloutput4 = _mm512_mask_blend_ps( lcl_relumask4, _mm512_setzero_ps(), lcl_vdeloutput4 ); + _mm512_store_act( del_output_ptr+48, lcl_vdeloutput4 ); + relumask_ptr += 2; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + _mm512_stream_act( del_input_add_ptr+48, lcl_vdeloutput4 ); + del_input_add_ptr += sw*64; +#endif + lcl_vdgamma4 = _mm512_add_ps( lcl_vdgamma4, _mm512_mul_ps( _mm512_mul_ps( _mm512_sub_ps( _mm512_load_act( input_ptr+48 ), lcl_vbmean4 ), lcl_vdeloutput4 ), lcl_vbrstd4 ) ); + lcl_vdbeta4 = _mm512_add_ps( lcl_vdbeta4, lcl_vdeloutput4 ); + + input_ptr += sw*64; + del_output_ptr += 64; + } + } + + _mm512_storeu_ps( del_gamma_img_ptr, lcl_vdgamma ); + _mm512_storeu_ps( del_beta_img_ptr, lcl_vdbeta ); + _mm512_storeu_ps( del_gamma_img_ptr+16, lcl_vdgamma2 ); + _mm512_storeu_ps( del_beta_img_ptr+16, lcl_vdbeta2 ); + _mm512_storeu_ps( del_gamma_img_ptr+32, lcl_vdgamma3 ); + _mm512_storeu_ps( del_beta_img_ptr+32, lcl_vdbeta3 ); + _mm512_storeu_ps( del_gamma_img_ptr+48, lcl_vdgamma4 ); + _mm512_storeu_ps( del_beta_img_ptr+48, lcl_vdbeta4 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the del_gamm and del_beta */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + __m512 lcl_vdgamma = _mm512_setzero_ps(); + __m512 lcl_vdbeta = _mm512_setzero_ps(); + + for ( img=0; img < nImg; img++ ) { + lcl_vdgamma = _mm512_add_ps( lcl_vdgamma, _mm512_loadu_ps( del_gamma_img_ptr ) ); + lcl_vdbeta = _mm512_add_ps( lcl_vdbeta, _mm512_loadu_ps( del_beta_img_ptr ) ); + del_gamma_img_ptr += 64; + del_beta_img_ptr += 64; + } + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, (fm/4), ((fm%4)*16), 64), lcl_vdgamma ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, (fm/4), ((fm%4)*16), 64), lcl_vdbeta ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we apply the actual backward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbmean, lcl_vbrstd, lcl_vdgamma, lcl_vdbeta; + __m512 lcl_vgamma2, lcl_vbmean2, lcl_vbrstd2, lcl_vdgamma2, lcl_vdbeta2; + __m512 lcl_vgamma3, lcl_vbmean3, lcl_vbrstd3, lcl_vdgamma3, lcl_vdbeta3; + __m512 lcl_vgamma4, lcl_vbmean4, lcl_vbrstd4, lcl_vdgamma4, lcl_vdbeta4; + __m512 lcl_vnhw = _mm512_set1_ps( nhw ); + __m512 lcl_vrec_nhw = _mm512_set1_ps( recp_nhw ); + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 64) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 64) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 64) ); + lcl_vdgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, 64) ); + lcl_vdbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, 64) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 64) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 64) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 64) ); + lcl_vdgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 16, 64) ); + lcl_vdbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 16, 64) ); + + lcl_vgamma3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 32, 64) ); + lcl_vbmean3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 32, 64) ); + lcl_vbrstd3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 32, 64) ); + lcl_vdgamma3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 32, 64) ); + lcl_vdbeta3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 32, 64) ); + + lcl_vgamma4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 48, 64) ); + lcl_vbmean4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 48, 64) ); + lcl_vbrstd4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 48, 64) ); + lcl_vdgamma4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 48, 64) ); + lcl_vdbeta4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 48, 64) ); + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + __m512 lcl_vdelinput; + __m512 lcl_vdelinput2; + __m512 lcl_vdelinput3; + __m512 lcl_vdelinput4; + + lcl_vdelinput = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vdgamma ); + lcl_vdelinput = _mm512_mul_ps( lcl_vdelinput, lcl_vbrstd ); + lcl_vdelinput = _mm512_add_ps( lcl_vdbeta, lcl_vdelinput ); + lcl_vdelinput = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr ) ), lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vbrstd, lcl_vdelinput ); + lcl_vdelinput = _mm512_mul_ps( lcl_vgamma, lcl_vdelinput ); + + lcl_vdelinput2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vdgamma2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vdelinput2, lcl_vbrstd2 ); + lcl_vdelinput2 = _mm512_add_ps( lcl_vdbeta2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+16 ) ), lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vbrstd2, lcl_vdelinput2 ); + lcl_vdelinput2 = _mm512_mul_ps( lcl_vgamma2, lcl_vdelinput2 ); + + lcl_vdelinput3 = _mm512_sub_ps( _mm512_load_act( input_ptr+32 ), lcl_vbmean3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vdelinput3, lcl_vdgamma3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vdelinput3, lcl_vbrstd3 ); + lcl_vdelinput3 = _mm512_add_ps( lcl_vdbeta3, lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+32 ) ), lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vbrstd3, lcl_vdelinput3 ); + lcl_vdelinput3 = _mm512_mul_ps( lcl_vgamma3, lcl_vdelinput3 ); + + lcl_vdelinput4 = _mm512_sub_ps( _mm512_load_act( input_ptr+48 ), lcl_vbmean4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vdelinput4, lcl_vdgamma4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vdelinput4, lcl_vbrstd4 ); + lcl_vdelinput4 = _mm512_add_ps( lcl_vdbeta4, lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_sub_ps( _mm512_mul_ps( lcl_vnhw, _mm512_load_act( del_output_ptr+48 ) ), lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vbrstd4, lcl_vdelinput4 ); + lcl_vdelinput4 = _mm512_mul_ps( lcl_vgamma4, lcl_vdelinput4 ); + + _mm512_stream_act( del_input_ptr, lcl_vdelinput ); + _mm512_stream_act( del_input_ptr+16, lcl_vdelinput2 ); + _mm512_stream_act( del_input_ptr+32, lcl_vdelinput3 ); + _mm512_stream_act( del_input_ptr+48, lcl_vdelinput4 ); + + del_input_ptr += sw*64; + input_ptr += sw*64; + del_output_ptr += 64; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..7fee60ee2a594d4d6dd7192ed3b8c82b56fce167 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c @@ -0,0 +1,264 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int nG = handle->desc.G; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; +const int nFmBlock = handle->ifmblock; +/* derive channels per group */ +const int nFmG = (nBlocksFm * nFmBlock) / nG; + +/* size of sample */ +const element_stats_type ghw = (element_stats_type)(nFmG * ifh * ifw); +const element_stats_type recp_ghw = 1.0f/ghw; +const element_stats_type eps = 1e-7f; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +/* @TODO let's fix parallelization to include channel groups while avoiding conflict misses */ +const int work = nImg; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* loop variables */ +int img = 0; +int fm = 0; +/*int imgfm = 0;*/ +int hi = 0; +int wi = 0; +int v = 0; +int ho = 0; +int wo = 0; +int g = 0; + +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +LIBXSMM_VLA_DECL(5, element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, element_input_type, dinput_add, (element_input_type* )handle->grad_add->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) +LIBXSMM_VLA_DECL(5, const element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); + +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, dgamma, (element_stats_type*)handle->grad_gamma->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, dbeta, (element_stats_type*)handle->grad_beta->data, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, nG); +LIBXSMM_VLA_DECL(2, const element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, nG); +LIBXSMM_VLA_DECL(2, const element_stats_type, variance, (element_stats_type*)handle->variance->data, nG); +LIBXSMM_VLA_DECL(3, element_stats_type, dgamma_img, (element_stats_type*)handle->scratch, nImg, nFmBlock); +LIBXSMM_VLA_DECL(3, element_stats_type, dbeta_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, d1_val_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * 2 * (size_t)nBlocksFm * (size_t)nFmBlock), nG); +LIBXSMM_VLA_DECL(2, element_stats_type, d2_val_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * 2 * (size_t)nBlocksFm * (size_t)nFmBlock) + ((size_t)nImg*(size_t)nG), nG); +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) +union libxsmm_bfloat16_hp input_f32; +union libxsmm_bfloat16_hp del_input_f32; +union libxsmm_bfloat16_hp del_output_f32; +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) +union libxsmm_bfloat16_hp output_f32; +output_f32.i[1] = 0; +output_f32.i[0] = 0; +#endif +input_f32.i[1] = 0; +input_f32.i[0] = 0; +del_output_f32.i[1] = 0; +del_output_f32.i[0] = 0; +del_input_f32.i[1] = 0; +del_input_f32.i[0] = 0; +#endif + +assert( nFmBlock <= 64 ); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for ( img = thr_begin; img < thr_end; ++img ) { + element_stats_type* d1_val_img_ptr = &LIBXSMM_VLA_ACCESS(2, d1_val_img, img, 0, nG); + element_stats_type* d2_val_img_ptr = &LIBXSMM_VLA_ACCESS(2, d2_val_img, img, 0, nG); + + for ( g = 0; g < nG; ++g ) { + d1_val_img_ptr[g] = 0.0f; + d2_val_img_ptr[g] = 0.0f; + } + + for ( fm = 0; fm < nBlocksFm; ++fm ) { + /* @TODO check if we can bake this in into scratch */ + element_stats_type lcl_gamma_ptr[64]; + element_stats_type lcl_beta_ptr[64]; + element_stats_type* del_gamma_img_ptr; + element_stats_type* del_beta_img_ptr; + + del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, nFmBlock); + del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + lcl_gamma_ptr[v] = 0.0f; + lcl_beta_ptr[v] = 0.0f; + } + + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + element_input_type* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(5, dinput_add, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + const element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + const unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); + const element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, img, 0, nG); + const element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, img, 0, nG); + const element_stats_type* gamma_ptr = &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, nFmBlock); + + for ( v=0; v < nFmBlock; v++ ) { + g = ((fm*nFmBlock)+v)/nFmG; +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) + del_output_f32.i[1] = del_output_ptr[v]; + del_output_f32.i[0] = 0; +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + output_f32.i[1] = output_ptr[v]; + del_output_f32.f = LIBXSMM_FEQ(output_f32.f, 0) ? 0 : del_output_f32.f; + del_output_ptr[v] = del_output_f32.i[1]; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + del_output_ptr[v] = (element_output_type)(relumask_ptr[v] == 1 ? del_output_ptr[v] : 0); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + del_input_add_ptr[v] = del_output_ptr[v]; +#endif + input_f32.i[1] = input_ptr[v]; + lcl_gamma_ptr[v] += (input_f32.f - bmean_ptr[g]) * del_output_f32.f * brstd_ptr[g]; + lcl_beta_ptr[v] += del_output_f32.f; + d1_val_img_ptr[g] += (input_f32.f - bmean_ptr[g]) * del_output_f32.f * gamma_ptr[v]; + d2_val_img_ptr[g] += del_output_f32.f * gamma_ptr[v]; +#else +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU) + del_output_ptr[v] = LIBXSMM_FEQ(output_ptr[v], 0) ? 0 : del_output_ptr[v]; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK) + del_output_ptr[v] = (element_output_type)(relumask_ptr[v] == 1 ? del_output_ptr[v] : 0); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE) + del_input_add_ptr[v] = del_output_ptr[v]; +#endif + lcl_gamma_ptr[v] += (input_ptr[v] - bmean_ptr[g]) * del_output_ptr[v] * brstd_ptr[g]; + lcl_beta_ptr[v] += del_output_ptr[v]; + d1_val_img_ptr[g] += (input_ptr[v] - bmean_ptr[g]) * del_output_ptr[v] * gamma_ptr[v]; + d2_val_img_ptr[g] += del_output_ptr[v] * gamma_ptr[v]; +#endif + } + } + } + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_img_ptr[v] = lcl_gamma_ptr[v]; + del_beta_img_ptr[v] = lcl_beta_ptr[v]; + } + } + + for ( fm = 0; fm < nBlocksFm; ++fm ) { + for ( hi=iph, ho=oph; hi < (ifh + iph); hi+=sh, ho++ ) { + for ( wi=ipw, wo=opw; wi < (ifw + ipw); wi+=sw, wo++ ) { + element_input_type* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + const element_output_type* del_output_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); + const element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, img, 0, nG); + const element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, img, 0, nG); + const element_stats_type* variance_ptr = &LIBXSMM_VLA_ACCESS(2, variance, img, 0, nG); + const element_stats_type* gamma_ptr = &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, nFmBlock); + +#if 0 +#if !defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif +#endif + for ( v=0; v < nFmBlock; v++ ) { + element_stats_type t0_val; + g = ((fm*nFmBlock)+v)/nFmG; + t0_val = brstd_ptr[g] * recp_ghw; +#if defined(LIBXSMM_DNN_FUSEDGN_BWD_BF16) + del_output_f32.i[1] = del_output_ptr[v]; + input_f32.i[1] = input_ptr[v]; + del_input_f32.f = t0_val * ((gamma_ptr[v] * ghw * del_output_f32.f) - d2_val_img_ptr[g] - ((input_f32.f - bmean_ptr[g]) * d1_val_img_ptr[g] * (1.0f/(variance_ptr[g] + eps)))); + del_input_ptr[v] = del_input_f32.i[1]; +#else + del_input_ptr[v] = t0_val * ((gamma_ptr[v] * ghw * del_output_ptr[v]) - d2_val_img_ptr[g] - ((input_ptr[v] - bmean_ptr[g]) * d1_val_img_ptr[g] * (1.0f/(variance_ptr[g] + eps)))); +#endif + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +/* now we need to reduce the del_gamm and del_beta */ +for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + element_stats_type* del_gamma_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma, fm, 0, nFmBlock); + element_stats_type* del_beta_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta, fm, 0, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_ptr[v] = (element_stats_type)0; + del_beta_ptr[v] = (element_stats_type)0; + } + + for ( img=0; img < nImg; img++ ) { + element_stats_type* del_gamma_img_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img, fm, img, 0, nImg, nFmBlock); + element_stats_type* del_beta_img_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img, fm, img, 0, nImg, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + del_gamma_ptr[v] += del_gamma_img_ptr[v]; + del_beta_ptr[v] += del_beta_img_ptr[v]; + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..9d1a104e750683b3399c8054118a8bdc7698ed36 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c @@ -0,0 +1,232 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.N * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, 16); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 16); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 16); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, 16); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * 16), nImg, 16); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 2); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, 16); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, 16); + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + lcl_vsum = _mm512_add_ps( lcl_vsum, lcl_vinput ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_mul_ps( lcl_vinput, lcl_vinput ) ); + + input_ptr += 16; + } + } + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, 0, 0, nImg, 16); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, 0, 0, nImg, 16); + + for ( img=0; img < nImg; img++ ) { + lcl_vsum = _mm512_add_ps( lcl_vsum, _mm512_loadu_ps( sum_img_ptr ) ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_loadu_ps( sumsq_img_ptr ) ); + sum_img_ptr += 16; + sumsq_img_ptr += 16; + } + + __m512 lcl_vsqrt_eps = _mm512_set1_ps(sqrt_eps); + __m512 lcl_vrec_nhw = _mm512_set1_ps(recp_nhw); + __m512 lcl_vone = _mm512_set1_ps(1.0); + __m512 lcl_vbmean, lcl_vbmeansq, lcl_vsqbmean, lcl_vbrstd, lcl_vvar; + lcl_vbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum ); /* E(X) */ + lcl_vbmeansq = _mm512_mul_ps( lcl_vbmean, lcl_vbmean ); /* E(X)^2 */ + lcl_vsqbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq ); /* E(X^2) */ + lcl_vvar = _mm512_sub_ps( lcl_vsqbmean, lcl_vbmeansq ); /* variance */ +#if 0 + { + __m512d lcl_voned = _mm512_set1_pd(1.0); + __m512d lcl_vepsd = _mm512_set1_pd(1e-7); + __m512d lcl_vlo = _mm512_cvtps_pd( _mm256_castpd_ps( _mm512_extractf64x4_pd( _mm512_castps_pd( lcl_vvar ), 0 ) ) ); + __m512d lcl_vhi = _mm512_cvtps_pd( _mm256_castpd_ps( _mm512_extractf64x4_pd( _mm512_castps_pd( lcl_vvar ), 1 ) ) ); + lcl_vlo = _mm512_sqrt_pd( _mm512_add_pd( lcl_vlo, lcl_vepsd ) ); + lcl_vhi = _mm512_sqrt_pd( _mm512_add_pd( lcl_vhi, lcl_vepsd ) ); + lcl_vlo = _mm512_div_pd( lcl_voned, lcl_vlo ); + lcl_vhi = _mm512_div_pd( lcl_voned, lcl_vhi ); + lcl_vbrstd = _mm512_castpd_ps( _mm512_insertf64x4( _mm512_setzero_pd(), _mm256_castps_pd( _mm512_cvtpd_ps( lcl_vlo ) ), 0 ) ); + lcl_vbrstd = _mm512_castpd_ps( _mm512_insertf64x4( _mm512_castps_pd( lcl_vbrstd ), _mm256_castps_pd( _mm512_cvtpd_ps( lcl_vhi ) ), 1 ) ); + } +#else + lcl_vbrstd = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar, lcl_vsqrt_eps ) ) ); +#endif + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16), lcl_vbmean ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16), lcl_vbrstd ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, fm, 0, 16), lcl_vvar ); + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbeta, lcl_vbmean, lcl_vbrstd; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 16) ); + lcl_vbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, 16) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 16) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 16) ); + + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#endif + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 2); +#endif + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + __m512 lcl_vo; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vo = _mm512_mul_ps( lcl_vgamma, lcl_vo ); + lcl_vo = _mm512_fmadd_ps( lcl_vo, lcl_vbrstd, lcl_vbeta ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo = _mm512_add_ps( lcl_vo, _mm512_load_act( input_add_ptr ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo = _mm512_max_ps( lcl_vo, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = _mm512_cmp_ps_mask( lcl_vo, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vo ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask ); + relumask_ptr += 2; +#endif + _mm512_stream_act( output_ptr, lcl_vo ); + + input_ptr += sw*16; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + input_add_ptr += sw*16; +#endif + output_ptr += 16; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); +} + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..6e238f891364e94b96149f4b60a5baadc0126b4f --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c @@ -0,0 +1,275 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.N * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, 32); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 32); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 32); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, 32); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * 32), nImg, 32); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 4); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + __m512 lcl_vsum2 = _mm512_setzero_ps(); + __m512 lcl_vsumsq2 = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, 32); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, 32); + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __m512 lcl_vinput2 = _mm512_load_act( input_ptr+16 ); + + lcl_vsum = _mm512_add_ps( lcl_vsum, lcl_vinput ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_mul_ps( lcl_vinput, lcl_vinput ) ); + + lcl_vsum2 = _mm512_add_ps( lcl_vsum2, lcl_vinput2 ); + lcl_vsumsq2 = _mm512_add_ps( lcl_vsumsq2, _mm512_mul_ps( lcl_vinput2, lcl_vinput2 ) ); + + input_ptr += 32; + } + } + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + + _mm512_storeu_ps( sum_img_ptr+16, lcl_vsum2 ); + _mm512_storeu_ps( sumsq_img_ptr+16, lcl_vsumsq2 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + __m512 lcl_vsum2 = _mm512_setzero_ps(); + __m512 lcl_vsumsq2 = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, 0, 0, nImg, 32); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, 0, 0, nImg, 32); + + for ( img=0; img < nImg; img++ ) { + lcl_vsum = _mm512_add_ps( lcl_vsum, _mm512_loadu_ps( sum_img_ptr ) ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_loadu_ps( sumsq_img_ptr ) ); + + lcl_vsum2 = _mm512_add_ps( lcl_vsum2, _mm512_loadu_ps( sum_img_ptr+16 ) ); + lcl_vsumsq2 = _mm512_add_ps( lcl_vsumsq2, _mm512_loadu_ps( sumsq_img_ptr+16 ) ); + + sum_img_ptr += 32; + sumsq_img_ptr += 32; + } + + __m512 lcl_vsqrt_eps = _mm512_set1_ps(sqrt_eps); + __m512 lcl_vrec_nhw = _mm512_set1_ps(recp_nhw); + __m512 lcl_vone = _mm512_set1_ps(1.0); + __m512 lcl_vbmean, lcl_vbmeansq, lcl_vsqbmean, lcl_vbrstd, lcl_vvar; + __m512 lcl_vbmean2, lcl_vbmeansq2, lcl_vsqbmean2, lcl_vbrstd2, lcl_vvar2; + + lcl_vbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum ); /* E(X) */ + lcl_vbmeansq = _mm512_mul_ps( lcl_vbmean, lcl_vbmean ); /* E(X)^2 */ + lcl_vsqbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq ); /* E(X^2) */ + lcl_vvar = _mm512_sub_ps( lcl_vsqbmean, lcl_vbmeansq ); /* variance */ + lcl_vbrstd = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar, lcl_vsqrt_eps ) ) ); + + lcl_vbmean2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum2 ); /* E(X) */ + lcl_vbmeansq2 = _mm512_mul_ps( lcl_vbmean2, lcl_vbmean2 ); /* E(X)^2 */ + lcl_vsqbmean2 = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq2 ); /* E(X^2) */ + lcl_vvar2 = _mm512_sub_ps( lcl_vsqbmean2, lcl_vbmeansq2 ); /* variance */ + lcl_vbrstd2 = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar2, lcl_vsqrt_eps ) ) ); + + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32), lcl_vbmean ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32), lcl_vbrstd ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, fm, 0, 32), lcl_vvar ); + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32), lcl_vbmean2 ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32), lcl_vbrstd2 ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, fm, 16, 32), lcl_vvar2 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbeta, lcl_vbmean, lcl_vbrstd; + __m512 lcl_vgamma2, lcl_vbeta2, lcl_vbmean2, lcl_vbrstd2; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 32) ); + lcl_vbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, 32) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 32) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 32) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 32) ); + lcl_vbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 16, 32) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 32) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 32) ); + + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#endif + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 4); +#endif + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + __m512 lcl_vo; + __m512 lcl_vo2; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask; + __mmask16 lcl_relumask2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vo = _mm512_mul_ps( lcl_vgamma, lcl_vo ); + lcl_vo = _mm512_fmadd_ps( lcl_vo, lcl_vbrstd, lcl_vbeta ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo = _mm512_add_ps( lcl_vo, _mm512_load_act( input_add_ptr ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo = _mm512_max_ps( lcl_vo, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = _mm512_cmp_ps_mask( lcl_vo, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vo ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vo2 = _mm512_mul_ps( lcl_vgamma2, lcl_vo2 ); + lcl_vo2 = _mm512_fmadd_ps( lcl_vo2, lcl_vbrstd2, lcl_vbeta2 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo2 = _mm512_add_ps( lcl_vo2, _mm512_load_act( input_add_ptr+16 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo2 = _mm512_max_ps( lcl_vo2, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = _mm512_cmp_ps_mask( lcl_vo2, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vo2 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask2 ); + relumask_ptr += 2; +#endif + + _mm512_stream_act( output_ptr, lcl_vo ); + _mm512_stream_act( output_ptr+16, lcl_vo2 ); + + input_ptr += sw*32; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + input_add_ptr += sw*32; +#endif + output_ptr += 32; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..332a88ad57ef9a135dcf4d515603512b86166406 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c @@ -0,0 +1,332 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel, delta gamma and beta reduction */ +const int work2 = nBlocksFm * 4; +/* compute chunk size */ +const int chunksize2 = (work2 % handle->desc.threads == 0) ? (work2 / handle->desc.threads) : ((work2 / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2; +const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; +const element_stats_type nhw = (element_stats_type)(handle->desc.N * ifh * ifw); +const element_stats_type recp_nhw = 1.0f/nhw; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int hi = 0; +int wi = 0; +int ho = 0; +int wo = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, 64); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, 64); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, 64); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nImg, 64); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * 64), nImg, 64); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, 8); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + __m512 lcl_vsum2 = _mm512_setzero_ps(); + __m512 lcl_vsumsq2 = _mm512_setzero_ps(); + __m512 lcl_vsum3 = _mm512_setzero_ps(); + __m512 lcl_vsumsq3 = _mm512_setzero_ps(); + __m512 lcl_vsum4 = _mm512_setzero_ps(); + __m512 lcl_vsumsq4 = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr; + element_stats_type* sumsq_img_ptr; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, fm, img, 0, nImg, 64); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, fm, img, 0, nImg, 64); + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __m512 lcl_vinput2 = _mm512_load_act( input_ptr+16 ); + __m512 lcl_vinput3 = _mm512_load_act( input_ptr+32 ); + __m512 lcl_vinput4 = _mm512_load_act( input_ptr+48 ); + + lcl_vsum = _mm512_add_ps( lcl_vsum, lcl_vinput ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_mul_ps( lcl_vinput, lcl_vinput ) ); + + lcl_vsum2 = _mm512_add_ps( lcl_vsum2, lcl_vinput2 ); + lcl_vsumsq2 = _mm512_add_ps( lcl_vsumsq2, _mm512_mul_ps( lcl_vinput2, lcl_vinput2 ) ); + + lcl_vsum3 = _mm512_add_ps( lcl_vsum3, lcl_vinput3 ); + lcl_vsumsq3 = _mm512_add_ps( lcl_vsumsq3, _mm512_mul_ps( lcl_vinput3, lcl_vinput3 ) ); + + lcl_vsum4 = _mm512_add_ps( lcl_vsum4, lcl_vinput4 ); + lcl_vsumsq4 = _mm512_add_ps( lcl_vsumsq4, _mm512_mul_ps( lcl_vinput4, lcl_vinput4 ) ); + + input_ptr += 64; + } + } + + _mm512_storeu_ps( sum_img_ptr, lcl_vsum ); + _mm512_storeu_ps( sumsq_img_ptr, lcl_vsumsq ); + + _mm512_storeu_ps( sum_img_ptr+16, lcl_vsum2 ); + _mm512_storeu_ps( sumsq_img_ptr+16, lcl_vsumsq2 ); + + _mm512_storeu_ps( sum_img_ptr+32, lcl_vsum3 ); + _mm512_storeu_ps( sumsq_img_ptr+32, lcl_vsumsq3 ); + + _mm512_storeu_ps( sum_img_ptr+48, lcl_vsum4 ); + _mm512_storeu_ps( sumsq_img_ptr+48, lcl_vsumsq4 ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we need to reduce the sum and sum^2, we use the final */ + for ( fm = thr_begin2; fm < thr_end2; ++fm ) { + __m512 lcl_vsum = _mm512_setzero_ps(); + __m512 lcl_vsumsq = _mm512_setzero_ps(); + element_stats_type* sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + element_stats_type* sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, (fm/4), 0, ((fm%4)*16), nImg, 64); + + for ( img=0; img < nImg; img++ ) { + lcl_vsum = _mm512_add_ps( lcl_vsum, _mm512_loadu_ps( sum_img_ptr ) ); + lcl_vsumsq = _mm512_add_ps( lcl_vsumsq, _mm512_loadu_ps( sumsq_img_ptr ) ); + + sum_img_ptr += 64; + sumsq_img_ptr += 64; + } + + __m512 lcl_vsqrt_eps = _mm512_set1_ps(sqrt_eps); + __m512 lcl_vrec_nhw = _mm512_set1_ps(recp_nhw); + __m512 lcl_vone = _mm512_set1_ps(1.0); + __m512 lcl_vbmean, lcl_vbmeansq, lcl_vsqbmean, lcl_vbrstd, lcl_vvar; + + lcl_vbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsum ); /* E(X) */ + lcl_vbmeansq = _mm512_mul_ps( lcl_vbmean, lcl_vbmean ); /* E(X)^2 */ + lcl_vsqbmean = _mm512_mul_ps( lcl_vrec_nhw, lcl_vsumsq ); /* E(X^2) */ + lcl_vvar = _mm512_sub_ps( lcl_vsqbmean, lcl_vbmeansq ); /* variance */ + lcl_vbrstd = _mm512_div_ps( lcl_vone, _mm512_sqrt_ps( _mm512_add_ps( lcl_vvar, lcl_vsqrt_eps ) ) ); + + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, (fm/4), ((fm%4)*16), 64), lcl_vbmean ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, (fm/4), ((fm%4)*16), 64), lcl_vbrstd ); + _mm512_storeu_ps( &LIBXSMM_VLA_ACCESS(2, variance, (fm/4), ((fm%4)*16), 64), lcl_vvar ); + } + + libxsmm_barrier_wait(handle->barrier, ltid); + + /* now we apply the actual forward batch norm */ + for ( imgfm = thr_begin; imgfm < thr_end; ++imgfm ) { + __m512 lcl_vgamma, lcl_vbeta, lcl_vbmean, lcl_vbrstd; + __m512 lcl_vgamma2, lcl_vbeta2, lcl_vbmean2, lcl_vbrstd2; + __m512 lcl_vgamma3, lcl_vbeta3, lcl_vbmean3, lcl_vbrstd3; + __m512 lcl_vgamma4, lcl_vbeta4, lcl_vbmean4, lcl_vbrstd4; + + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + lcl_vgamma = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, 64) ); + lcl_vbeta = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, 64) ); + lcl_vbmean = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 0, 64) ); + lcl_vbrstd = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 0, 64) ); + + lcl_vgamma2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 16, 64) ); + lcl_vbeta2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 16, 64) ); + lcl_vbmean2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 16, 64) ); + lcl_vbrstd2 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 16, 64) ); + + lcl_vgamma3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 32, 64) ); + lcl_vbeta3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 32, 64) ); + lcl_vbmean3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 32, 64) ); + lcl_vbrstd3 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 32, 64) ); + + lcl_vgamma4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, gamma, fm, 48, 64) ); + lcl_vbeta4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, beta, fm, 48, 64) ); + lcl_vbmean4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, bmean, fm, 48, 64) ); + lcl_vbrstd4 = _mm512_loadu_ps( &LIBXSMM_VLA_ACCESS(2, brstd, fm, 48, 64) ); + + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#endif + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 8); +#endif + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + __m512 lcl_vo; + __m512 lcl_vo2; + __m512 lcl_vo3; + __m512 lcl_vo4; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + __mmask16 lcl_relumask; + __mmask16 lcl_relumask2; + __mmask16 lcl_relumask3; + __mmask16 lcl_relumask4; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo = _mm512_sub_ps( _mm512_load_act( input_ptr ), lcl_vbmean ); + lcl_vo = _mm512_mul_ps( lcl_vgamma, lcl_vo ); + lcl_vo = _mm512_fmadd_ps( lcl_vo, lcl_vbrstd, lcl_vbeta ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo = _mm512_add_ps( lcl_vo, _mm512_load_act( input_add_ptr ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo = _mm512_max_ps( lcl_vo, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask = _mm512_cmp_ps_mask( lcl_vo, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo = _mm512_mask_blend_ps( lcl_relumask, _mm512_setzero_ps(), lcl_vo ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo2 = _mm512_sub_ps( _mm512_load_act( input_ptr+16 ), lcl_vbmean2 ); + lcl_vo2 = _mm512_mul_ps( lcl_vgamma2, lcl_vo2 ); + lcl_vo2 = _mm512_fmadd_ps( lcl_vo2, lcl_vbrstd2, lcl_vbeta2 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo2 = _mm512_add_ps( lcl_vo2, _mm512_load_act( input_add_ptr+16 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo2 = _mm512_max_ps( lcl_vo2, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask2 = _mm512_cmp_ps_mask( lcl_vo2, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo2 = _mm512_mask_blend_ps( lcl_relumask2, _mm512_setzero_ps(), lcl_vo2 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask2 ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo3 = _mm512_sub_ps( _mm512_load_act( input_ptr+32 ), lcl_vbmean3 ); + lcl_vo3 = _mm512_mul_ps( lcl_vgamma3, lcl_vo3 ); + lcl_vo3 = _mm512_fmadd_ps( lcl_vo3, lcl_vbrstd3, lcl_vbeta3 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo3 = _mm512_add_ps( lcl_vo3, _mm512_load_act( input_add_ptr+32 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo3 = _mm512_max_ps( lcl_vo3, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask3 = _mm512_cmp_ps_mask( lcl_vo3, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo3 = _mm512_mask_blend_ps( lcl_relumask3, _mm512_setzero_ps(), lcl_vo3 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask3 ); + relumask_ptr += 2; +#endif + + /* BN + scale (gamma, beta) */ + lcl_vo4 = _mm512_sub_ps( _mm512_load_act( input_ptr+48 ), lcl_vbmean4 ); + lcl_vo4 = _mm512_mul_ps( lcl_vgamma4, lcl_vo4 ); + lcl_vo4 = _mm512_fmadd_ps( lcl_vo4, lcl_vbrstd4, lcl_vbeta4 ); + /* eltwise add */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + lcl_vo4 = _mm512_add_ps( lcl_vo4, _mm512_load_act( input_add_ptr+48 ) ); +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + lcl_vo4 = _mm512_max_ps( lcl_vo4, _mm512_setzero_ps() ); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + lcl_relumask4 = _mm512_cmp_ps_mask( lcl_vo4, _mm512_setzero_ps(), _CMP_GT_OQ ); + lcl_vo4 = _mm512_mask_blend_ps( lcl_relumask4, _mm512_setzero_ps(), lcl_vo4 ); + LIBXSMM_INTRINSICS_MM512_STORE_MASK16( relumask_ptr, lcl_relumask4 ); + relumask_ptr += 2; +#endif + + _mm512_stream_act( output_ptr, lcl_vo ); + _mm512_stream_act( output_ptr+16, lcl_vo2 ); + _mm512_stream_act( output_ptr+32, lcl_vo3 ); + _mm512_stream_act( output_ptr+48, lcl_vo4 ); + + input_ptr += sw*64; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + input_add_ptr += sw*64; +#endif + output_ptr += 64; + } + } + } + + libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..89b70194bfa397b9985075f66341383110147f9c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c @@ -0,0 +1,229 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int nG = handle->desc.G; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = ifh/sh; +const int ofw = ifw/sw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; +const int nFmBlock = handle->ifmblock; +/* derive channels per group */ +const int nFmG = (nBlocksFm * nFmBlock) / nG; +/* size of sample */ +const element_stats_type ghw = (element_stats_type)(nFmG * ifh * ifw); +const element_stats_type recp_ghw = 1.0f/ghw; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +/* @TODO let's fix parallelization to include channel groups while avoiding conflict misses */ +const int work = nImg; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* eps to avoid sqrt of zero */ +const element_stats_type sqrt_eps = 1e-7f; + +/* loop variables */ +int img = 0; +int fm = 0; +/*int imgfm = 0;*/ +int hi = 0; +int wi = 0; +int v = 0; +int ho = 0; +int wo = 0; +int g = 0; + +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) +LIBXSMM_VLA_DECL(5, const element_input_type, input_add, (element_input_type* )handle->reg_add->data, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, gamma, (element_stats_type*)handle->reg_gamma->data, nFmBlock); +LIBXSMM_VLA_DECL(2, const element_stats_type, beta, (element_stats_type*)handle->reg_beta->data, nFmBlock); +LIBXSMM_VLA_DECL(2, element_stats_type, bmean, (element_stats_type*)handle->expvalue->data, nG); +LIBXSMM_VLA_DECL(2, element_stats_type, brstd, (element_stats_type*)handle->rcpstddev->data, nG); +LIBXSMM_VLA_DECL(2, element_stats_type, variance, (element_stats_type*)handle->variance->data, nG); +LIBXSMM_VLA_DECL(3, element_stats_type, sum_img, (element_stats_type*)handle->scratch, nBlocksFm, nFmBlock); +LIBXSMM_VLA_DECL(3, element_stats_type, sumsq_img, ((element_stats_type*)handle->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nBlocksFm, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) +LIBXSMM_VLA_DECL(5, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) +union libxsmm_bfloat16_hp input_f32; +union libxsmm_bfloat16_hp output_f32; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) +union libxsmm_bfloat16_hp input_add_f32; +input_add_f32.i[1] = 0; +input_add_f32.i[0] = 0; +#endif +input_f32.i[1] = 0; +input_f32.i[0] = 0; +output_f32.i[1] = 0; +output_f32.i[0] = 0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for ( img = thr_begin; img < thr_end; ++img ) { + element_stats_type* bmean_ptr = &LIBXSMM_VLA_ACCESS(2, bmean, img, 0, nG); + element_stats_type* brstd_ptr = &LIBXSMM_VLA_ACCESS(2, brstd, img, 0, nG); + element_stats_type* tvar_ptr = &LIBXSMM_VLA_ACCESS(2, variance, img, 0, nG); + element_stats_type* sum_img_ptr = NULL; + element_stats_type* sumsq_img_ptr = NULL; + + /* create reduction over all pixels per channel */ + for ( fm = 0; fm < nBlocksFm; ++fm ) { + /* @TODO check if we can bake this in into scratch */ + element_stats_type lcl_sum_ptr[64]; + element_stats_type lcl_sumsq_ptr[64]; + + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, img, fm, 0, nBlocksFm, nFmBlock); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, img, fm, 0, nBlocksFm, nFmBlock); + + LIBXSMM_PRAGMA_SIMD + for ( v=0; v < nFmBlock; v++ ) { + lcl_sum_ptr[v] = (element_stats_type)0; + lcl_sumsq_ptr[v] = (element_stats_type)0; + } + + for ( hi=iph; hi < (ifh + iph); hi++ ) { + for ( wi=ipw; wi < (ifw + ipw); wi++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); + +#if !defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for (v=0; v < nFmBlock; v++) { +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) + input_f32.i[1] = input_ptr[v]; + lcl_sum_ptr[v] += input_f32.f; + lcl_sumsq_ptr[v] += (input_f32.f * input_f32.f); +#else + lcl_sum_ptr[v] += input_ptr[v]; + lcl_sumsq_ptr[v] += (input_ptr[v] * input_ptr[v]); +#endif + } + } + } + + LIBXSMM_PRAGMA_SIMD + for (v=0; v < nFmBlock; v++) { + sum_img_ptr[v] = lcl_sum_ptr[v]; + sumsq_img_ptr[v] = lcl_sumsq_ptr[v]; + } + } + + /* new we compute mean, variance and rstd per channel group */ + sum_img_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img, img, 0, 0, nImg, nFmBlock); + sumsq_img_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img, img, 0, 0, nImg, nFmBlock); + for ( g = 0; g < nG; ++g ) { + element_stats_type lcl_fm_sum = 0.0f; + element_stats_type lcl_fm_sumsq = 0.0f; + + for ( fm = g*nFmG; fm < (g+1)*nFmG; ++fm ) { + lcl_fm_sum += sum_img_ptr[fm]; + lcl_fm_sumsq += sumsq_img_ptr[fm]; + } + + { + const element_stats_type tbmean = (recp_ghw * lcl_fm_sum); + const element_stats_type tbmeansq = tbmean * tbmean; + const element_stats_type tsqbmean = recp_ghw * lcl_fm_sumsq; + const element_stats_type tvar = tsqbmean - tbmeansq; + const element_stats_type tbrstd = (element_stats_type)(1.0/sqrt((double)tvar + sqrt_eps)); + bmean_ptr[g] = tbmean; + brstd_ptr[g] = tbrstd; + tvar_ptr[g] = tvar; + } + } + + /* let's scale the data */ + for ( fm = 0; fm < nBlocksFm; ++fm ) { + for ( hi=iph, ho=oph; hi < (ifh+iph); hi+=sh, ho++ ) { + for ( wi=ipw, wo=opw; wi < (ifw+ipw); wi+=sw, wo++ ) { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) + const element_input_type* input_add_ptr = &LIBXSMM_VLA_ACCESS(5, input_add, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#endif + const element_stats_type* gamma_ptr = &LIBXSMM_VLA_ACCESS(2, gamma, fm, 0, nFmBlock); + const element_stats_type* beta_ptr = &LIBXSMM_VLA_ACCESS(2, beta, fm, 0, nFmBlock); + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + unsigned char* relumask_ptr = &LIBXSMM_VLA_ACCESS(5, relumask, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#endif + float o; + +#if 0 +#if !defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif +#endif + for (v = 0; v < nFmBlock; v++ ) { + g = ((fm*nFmBlock)+v)/nFmG; +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) + input_f32.i[1] = input_ptr[v]; + o = gamma_ptr[v]*(input_f32.f - bmean_ptr[g])*brstd_ptr[g] + beta_ptr[v]; +#else + /* BN + scale (gamma, beta) */ + o = gamma_ptr[v]*(input_ptr[v] - bmean_ptr[g])*brstd_ptr[g] + beta_ptr[v]; +#endif + /* Eltwise */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE) +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) + input_add_f32.i[1] = input_add_ptr[v]; + o += input_add_f32.f; +#else + o += input_add_ptr[v]; +#endif +#endif + /* ReLU */ +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU) + o = ( o > 0.0f ) ? o : 0.0f; +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK) + o = ( o > 0.0f ) ? o : 0.0f; + relumask_ptr[v] = (unsigned char)(o > 0.0f ? 1 : 0); +#endif +#if defined(LIBXSMM_DNN_FUSEDGN_FWD_BF16) + output_f32.f = o; + output_ptr[v] = output_f32.i[1]; +#else + output_ptr[v] = o; +#endif + } + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..1818ab34f15756490eb17aa0c86521cd9c9c070b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c @@ -0,0 +1,91 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512) +# define _mm512_load_fil(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +# define _mm512_store_fil(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16((B)),16))) +#endif + +/* loop counters */ +libxsmm_blasint i; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* number of tasks that could run in parallel for the filters */ +const int work = handle->desc.C * handle->desc.K; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +element_filter_type* filter = (element_filter_type*)handle->reg_filter->data; +element_filter_type* dfilter = (element_filter_type*)handle->grad_filter->data; +#if defined(LIBXSMM_DNN_OPTIMIZER_SGD_BF16) || defined(LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512) +element_master_type* master = (element_master_type*)handle->master_filter->data; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init( handle->barrier, ltid ); + +#if defined(LIBXSMM_DNN_OPTIMIZER_SGD_BF16) || defined(LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512) +#if defined(LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512) +{ + libxsmm_blasint iv = ( (thr_end-thr_begin)/16 ) * 16; /* compute iterations which are vectorizable */ + __m512 vlr = _mm512_set1_ps( handle->desc.learning_rate ); + for ( i = thr_begin; i desc.learning_rate*t1.f); + t2.f = master[i]; + filter[i] = t2.i[1]; + } +} +#undef _mm512_load_fil +#undef _mm512_store_fil +#else +for ( i = thr_begin; i < thr_end; ++i ) { + libxsmm_bfloat16_hp t1, t2; + t1.i[0] =0; + t1.i[1] = dfilter[i]; + master[i] = master[i] - (handle->desc.learning_rate*t1.f); + t2.f = master[i]; + filter[i] = t2.i[1]; +} +#endif +#else +#if defined(LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512) +{ + libxsmm_blasint iv = ( (thr_end-thr_begin)/16 ) * 16; /* compute iterations which are vectorizable */ + __m512 vlr = _mm512_set1_ps( handle->desc.learning_rate ); + for ( i = thr_begin; i < thr_begin + iv; i+=16 ) { + _mm512_storeu_ps( filter+i, _mm512_sub_ps( _mm512_loadu_ps( filter+i ), _mm512_mul_ps( vlr, _mm512_loadu_ps( dfilter + i ) ) ) ) ; + } + for ( i = thr_begin + iv; i < thr_end; ++i ) { + filter[i] = filter[i] - (handle->desc.learning_rate*dfilter[i]); + } +} +#else +for ( i = thr_begin; i < thr_end; ++i ) { + filter[i] = filter[i] - (handle->desc.learning_rate*dfilter[i]); +} +#endif +#endif + +libxsmm_barrier_wait( handle->barrier, ltid ); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..72e92417bc28b793703d650044df3a537886c39b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c @@ -0,0 +1,153 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +const int sh = handle->desc.u; +const int sw = handle->desc.v; +#endif +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +int kh = 0; +int kw = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_input_type recp_pool_size = 1.0f/((element_input_type)handle->desc.R*(element_input_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)16*(size_t)ltid); +LIBXSMM_VLA_DECL(3, float, lcl_dinput, lcl_buffer_ptr, ifw, 16); +#else +element_output_type* lcl_buffer_ptr = ((element_input_type*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)16*(size_t)ltid); +LIBXSMM_VLA_DECL(3, element_input_type, lcl_dinput, lcl_buffer_ptr, ifw, 16); +#endif +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 16); +LIBXSMM_VLA_DECL(5, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 16); +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) +LIBXSMM_VLA_DECL(5, const element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 16); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + for ( v = 0; v < ifh*ifw*16; v += 16 ) { + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); + } + +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) + for ( ho = oph; ho < (ofh+oph); ho++ ) { + for ( wo = opw; wo < (ofw+opw); wo++ ) { + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, 16); + const element_mask_type* mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 16); + + __m512 lcl_vdinput = _mm512_i32gather_ps( _mm512_loadu_si512( mask_ptr ), lcl_buffer_ptr, 4 ); + lcl_vdinput = _mm512_add_ps( lcl_vdinput, _mm512_load_act( doutput_ptr ) ); + _mm512_i32scatter_ps( lcl_buffer_ptr, _mm512_loadu_si512( mask_ptr ), lcl_vdinput, 4 ); + } + } +#endif +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) + for ( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for ( wo = opw; wo < (ofw+opw); wo++ ) { + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for ( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for ( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, 16); + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi+kh, wi+kw, 0, ifw, 16); + const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); + const __m512 lcl_dinput_ps = _mm512_loadu_ps( lcl_dinput_ptr ); + _mm512_storeu_ps( lcl_dinput_ptr, _mm512_fmadd_ps( _mm512_load_act( doutput_ptr ), recp_pool_size_ps, lcl_dinput_ps ) ); + } + } + } + } + } +#endif + + /* copy the local buffer into dinput activations */ + for ( hi = iph; hi < (ifh+iph); hi++ ) { + for ( wi = ipw; wi < (ifw+ipw); wi++ ) { + element_input_type* dinput_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, 16); + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi-iph, wi-ipw, 0, ifw, 16); + _mm512_stream_act( dinput_ptr, _mm512_loadu_ps( lcl_dinput_ptr ) ); + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..b5740474d4e2220ef712ff89a44c99ae3e8b33e2 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c @@ -0,0 +1,161 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +const int sh = handle->desc.u; +const int sw = handle->desc.v; +#endif +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +int kh = 0; +int kw = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_input_type recp_pool_size = 1.0f/((element_input_type)handle->desc.R*(element_input_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)32*(size_t)ltid); +LIBXSMM_VLA_DECL(3, float, lcl_dinput, lcl_buffer_ptr, ifw, 32); +#else +element_output_type* lcl_buffer_ptr = ((element_input_type*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)32*(size_t)ltid); +LIBXSMM_VLA_DECL(3, element_input_type, lcl_dinput, lcl_buffer_ptr, ifw, 32); +#endif +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 32); +LIBXSMM_VLA_DECL(5, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 32); +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) +LIBXSMM_VLA_DECL(5, const element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 32); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + for( v = 0; v < ifh*ifw*32; v += 16 ) { + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); + } + +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) + for( ho = oph; ho < (ofh+oph); ho++ ) { + for( wo = opw; wo < (ofw+opw); wo++ ) { + __m512 lcl_vdinput, lcl_vdinput2; + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, 32); + const element_mask_type* mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 32); + + lcl_vdinput = _mm512_i32gather_ps( _mm512_loadu_si512( mask_ptr ), lcl_buffer_ptr, 4 ); + lcl_vdinput = _mm512_add_ps( lcl_vdinput, _mm512_load_act( doutput_ptr ) ); + _mm512_i32scatter_ps( lcl_buffer_ptr, _mm512_loadu_si512( mask_ptr ), lcl_vdinput, 4 ); + + lcl_vdinput2 = _mm512_i32gather_ps( _mm512_loadu_si512( mask_ptr+16 ), lcl_buffer_ptr, 4 ); + lcl_vdinput2 = _mm512_add_ps( lcl_vdinput2, _mm512_load_act( doutput_ptr+16 ) ); + _mm512_i32scatter_ps( lcl_buffer_ptr, _mm512_loadu_si512( mask_ptr+16 ), lcl_vdinput2, 4 ); + } + } +#endif +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) + for( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for( wo = opw; wo < (ofw+opw); wo++ ) { + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, 32); + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi+kh, wi+kw, 0, ifw, 32); + const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); + const __m512 lcl_dinput_ps = _mm512_loadu_ps( lcl_dinput_ptr ); + const __m512 lcl_dinput_ps2 = _mm512_loadu_ps( lcl_dinput_ptr+16 ); + _mm512_storeu_ps( lcl_dinput_ptr, _mm512_fmadd_ps( _mm512_load_act( doutput_ptr ), recp_pool_size_ps, lcl_dinput_ps ) ); + _mm512_storeu_ps( lcl_dinput_ptr+16, _mm512_fmadd_ps( _mm512_load_act( doutput_ptr+16 ), recp_pool_size_ps, lcl_dinput_ps2 ) ); + } + } + } + } + } +#endif + + /* copy the local buffer into dinput activations */ + for( hi = iph; hi < (ifh+iph); hi++ ) { + for( wi = ipw; wi < (ifw+ipw); wi++ ) { + element_input_type* dinput_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, 32); + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi-iph, wi-ipw, 0, ifw, 32); + _mm512_stream_act( dinput_ptr, _mm512_loadu_ps( lcl_dinput_ptr ) ); + _mm512_stream_act( dinput_ptr+16, _mm512_loadu_ps( lcl_dinput_ptr+16 ) ); + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..70b652021213b1d74c13e1e68fc6a6414cf80481 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c @@ -0,0 +1,170 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +const int sh = handle->desc.u; +const int sw = handle->desc.v; +#endif +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm * 4; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm1 = 0; +int fm2 = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +int kh = 0; +int kw = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_input_type recp_pool_size = 1.0f/((element_input_type)handle->desc.R*(element_input_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)64*(size_t)ltid); +LIBXSMM_VLA_DECL(3, float, lcl_dinput, lcl_buffer_ptr, ifw, 16); +#else +element_output_type* lcl_buffer_ptr = ((element_input_type*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)64*(size_t)ltid); +LIBXSMM_VLA_DECL(3, element_input_type, lcl_dinput, lcl_buffer_ptr, ifw, 16); +#endif +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 64); +LIBXSMM_VLA_DECL(5, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 64); +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) +LIBXSMM_VLA_DECL(5, const element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 64); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { + img = imgfm / (nBlocksFm*4); + fm1 = imgfm % (nBlocksFm*4); + fm2 = imgfm % (nBlocksFm*4); + fm1 = fm1/4; + fm2 = (fm2%4)*16; + + for( v = 0; v < ifh*ifw*16; v += 16 ) { + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); + } + +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) + for( ho = oph; ho < (ofh+oph); ho++ ) { + for( wo = opw; wo < (ofw+opw); wo++ ) { + __m512 lcl_vdinput/*, lcl_vdinput2, lcl_vdinput3, lcl_vdinput4*/; + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm1, ho, wo, fm2, nBlocksFm, ofhp, ofwp, 64); + const element_mask_type* mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask, img, fm1, ho-oph, wo-opw, fm2, nBlocksFm, ofh, ofw, 64); +#if 1 + lcl_vdinput = _mm512_i32gather_ps( _mm512_loadu_si512( mask_ptr ), lcl_buffer_ptr, 4 ); + lcl_vdinput = _mm512_add_ps( lcl_vdinput, _mm512_load_act( doutput_ptr ) ); + _mm512_i32scatter_ps( lcl_buffer_ptr, _mm512_loadu_si512( mask_ptr ), lcl_vdinput, 4 ); +#else + for ( v = 0; v < 16; ++v ) { +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + union libxsmm_bfloat16_hp del_output_f32; + del_output_f32.i[1] = doutput_ptr[v]; + del_output_f32.i[0] = 0; + lcl_buffer_ptr[mask_ptr[v]] += del_output_f32.f; +#else + lcl_buffer_ptr[mask_ptr[v]] += doutput_ptr[v]; +#endif + } +#endif + } + } +#endif +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) + for( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for( wo = opw; wo < (ofw+opw); wo++ ) { + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm1, ho, wo, fm2, nBlocksFm, ofhp, ofwp, 64); + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi+kh, wi+kw, 0, ifw, 16); + const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); + const __m512 lcl_dinput_ps = _mm512_loadu_ps( lcl_dinput_ptr ); + _mm512_storeu_ps( lcl_dinput_ptr, _mm512_fmadd_ps( _mm512_load_act( doutput_ptr ), recp_pool_size_ps, lcl_dinput_ps ) ); + } + } + } + } + } +#endif + + /* copy the local buffer into dinput activations */ + for( hi = iph; hi < (ifh+iph); hi++ ) { + for( wi = ipw; wi < (ifw+ipw); wi++ ) { + element_input_type* dinput_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm1, hi, wi, fm2, nBlocksFm, ifhp, ifwp, 64); + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi-iph, wi-ipw, 0, ifw, 16); + _mm512_stream_act( dinput_ptr, _mm512_loadu_ps( lcl_dinput_ptr ) ); + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..805db71fcc3b88ef42f90150e3046b8523ee00ef --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c @@ -0,0 +1,184 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +const int sh = handle->desc.u; +const int sw = handle->desc.v; +#endif +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; +const int nFmBlock = handle->ifmblock; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) +int kh = 0; +int kw = 0; +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_input_type recp_pool_size = 1.0f/((element_input_type)handle->desc.R*(element_input_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +float *const lcl_buffer_ptr = (float*)handle->scratch + (size_t)ifh*ifw*nFmBlock*ltid; +LIBXSMM_VLA_DECL(3, float, lcl_dinput, lcl_buffer_ptr, ifw, nFmBlock); +#else +element_output_type *const lcl_buffer_ptr = (element_input_type*)handle->scratch + (size_t)ifh*ifw*nFmBlock*ltid; +LIBXSMM_VLA_DECL(3, element_input_type, lcl_dinput, lcl_buffer_ptr, ifw, nFmBlock); +#endif +LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +LIBXSMM_VLA_DECL(5, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) +LIBXSMM_VLA_DECL(5, const element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) +union libxsmm_bfloat16_hp del_input_f32; +union libxsmm_bfloat16_hp del_output_f32; +del_input_f32.i[1] = 0; +del_input_f32.i[0] = 0; +del_output_f32.i[1] = 0; +del_output_f32.i[0] = 0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + LIBXSMM_PRAGMA_SIMD + for ( v = 0; v < ifh*ifw*nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + lcl_buffer_ptr[v] = (float)0; +#else + lcl_buffer_ptr[v] = (element_input_type)0; +#endif + } + +#if defined(LIBXSMM_DNN_POOLING_BWD_MAX) + for ( ho = oph; ho < (ofh+oph); ho++ ) { + for ( wo = opw; wo < (ofw+opw); wo++ ) { + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); + const element_mask_type* mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, nFmBlock); + +#if !defined(LIBXSMM_DNN_POOLING_BWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for ( v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + del_output_f32.i[1] = doutput_ptr[v]; + lcl_buffer_ptr[mask_ptr[v]] += del_output_f32.f; +#else + lcl_buffer_ptr[mask_ptr[v]] += doutput_ptr[v]; +#endif + } + } + } +#endif +#if defined(LIBXSMM_DNN_POOLING_BWD_AVG) + for ( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for ( wo = opw; wo < (ofw+opw); wo++ ) { + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for ( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for ( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi+kh, wi+kw, 0, ifw, nFmBlock); +#else + element_input_type* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi+kh, wi+kw, 0, ifw, nFmBlock); +#endif + +#if !defined(LIBXSMM_DNN_POOLING_BWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for ( v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + del_output_f32.i[1] = doutput_ptr[v]; + lcl_dinput_ptr[v] += (del_output_f32.f * recp_pool_size); +#else + lcl_dinput_ptr[v] += (doutput_ptr[v] * recp_pool_size); +#endif + } + } + } + } + } + } +#endif + + /* copy the local buffer into dinput activations */ + for ( hi = iph; hi < (ifh+iph); hi++ ) { + for ( wi = ipw; wi < (ifw+ipw); wi++ ) { + element_input_type* dinput_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi-iph, wi-ipw, 0, ifw, nFmBlock); +#else + element_input_type* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi-iph, wi-ipw, 0, ifw, nFmBlock); +#endif + +#if !defined(LIBXSMM_DNN_POOLING_BWD_BF16) + LIBXSMM_PRAGMA_SIMD +#endif + for ( v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_BWD_BF16) + del_input_f32.f = lcl_dinput_ptr[v]; + dinput_ptr[v] = del_input_f32.i[1]; +#else + dinput_ptr[v] = lcl_dinput_ptr[v]; +#endif + } + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..76137cd514fd84db90c7bed34fa32b894119f5a1 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c @@ -0,0 +1,171 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int kh = 0; +int kw = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_output_type recp_pool_size = 1.0f/((element_output_type)handle->desc.R*(element_output_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ofh*(size_t)ofw*(size_t)16*(size_t)ltid); +LIBXSMM_VLA_DECL(3, float, lcl_output, lcl_buffer_ptr, ofw, 16); +#else +element_output_type* lcl_buffer_ptr = ((element_output_type*)handle->scratch)+((size_t)ofh*(size_t)ofw*(size_t)16*(size_t)ltid); +LIBXSMM_VLA_DECL(3, element_output_type, lcl_output, lcl_buffer_ptr, ofw, 16); +#endif +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 16); +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 16); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) +LIBXSMM_VLA_DECL(5, element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 16); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_viadd = _mm512_set_epi32( 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 ); +#endif + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + for ( v = 0; v < ofh*ofw*16; v+=16 ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_set1_ps(-FLT_MAX) ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); +#endif + } + + for ( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for ( wo = opw; wo < (ofw+opw); wo++ ) { + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, 16); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_vmask = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 16) ); +#endif + __m512 lcl_voutput = _mm512_loadu_ps( lcl_output_ptr ); + + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for ( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for ( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi+kh+iph, wi+kw+ipw, 0, nBlocksFm, ifhp, ifwp, 16); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_vnewmask = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*16 + (wi+kw)*16) ); + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __mmask16 lcl_mlt = _mm512_cmp_ps_mask( lcl_voutput, lcl_vinput, _CMP_LT_OS ); + lcl_voutput = _mm512_mask_blend_ps( lcl_mlt, lcl_voutput, lcl_vinput ); + lcl_vmask = _mm512_mask_blend_epi32( lcl_mlt, lcl_vmask, lcl_vnewmask ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + lcl_voutput = _mm512_add_ps( lcl_voutput, _mm512_load_act( input_ptr ) ); +#endif + } + } + } +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 16), lcl_vmask ); +#endif + _mm512_storeu_ps( lcl_output_ptr, lcl_voutput ); + } + } + + /* copy the local buffer into output activations */ + for ( ho = oph; ho < (ofh+oph); ho++ ) { + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 16); + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, 0, 0, ofw, 16); + for ( wo = opw; wo < (ofw+opw); wo++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_stream_act( output_ptr, _mm512_loadu_ps( lcl_output_ptr ) ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + _mm512_stream_act( output_ptr, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr ), recp_pool_size_ps ) ); +#endif + output_ptr += 16; + lcl_output_ptr += 16; + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..7f53b509ba5a7ba77bb14f37ba7dfb68ab4678a6 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c @@ -0,0 +1,181 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int kh = 0; +int kw = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_output_type recp_pool_size = 1.0f/((element_output_type)handle->desc.R*(element_output_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ofh*(size_t)ofw*(size_t)32*(size_t)ltid); +LIBXSMM_VLA_DECL(3, float, lcl_output, lcl_buffer_ptr, ofw, 32); +#else +element_output_type* lcl_buffer_ptr = ((element_output_type*)handle->scratch)+((size_t)ofh*(size_t)ofw*(size_t)32*(size_t)ltid); +LIBXSMM_VLA_DECL(3, element_output_type, lcl_output, lcl_buffer_ptr, ofw, 32); +#endif +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 32); +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 32); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) +LIBXSMM_VLA_DECL(5, element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 32); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_viadd = _mm512_set_epi32( 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 ); +#endif + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + for( v = 0; v < ofh*ofw*32; v+=16 ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_set1_ps(-FLT_MAX) ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); +#endif + } + + for( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for( wo = opw; wo < (ofw+opw); wo++ ) { + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, 32); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_vmask = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 32) ); + __m512i lcl_vmask2 = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 16, nBlocksFm, ofh, ofw, 32) ); +#endif + __m512 lcl_voutput = _mm512_loadu_ps( lcl_output_ptr ); + __m512 lcl_voutput2 = _mm512_loadu_ps( lcl_output_ptr+16 ); + + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi+kh+iph, wi+kw+ipw, 0, nBlocksFm, ifhp, ifwp, 32); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_vnewmask = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*32 + (wi+kw)*32) ); + __m512i lcl_vnewmask2 = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*32 + (wi+kw)*32 + 16) ); + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __m512 lcl_vinput2 = _mm512_load_act( input_ptr+16 ); + __mmask16 lcl_mlt = _mm512_cmp_ps_mask( lcl_voutput, lcl_vinput, _CMP_LT_OS ); + __mmask16 lcl_mlt2 = _mm512_cmp_ps_mask( lcl_voutput2, lcl_vinput2, _CMP_LT_OS ); + lcl_voutput = _mm512_mask_blend_ps( lcl_mlt, lcl_voutput, lcl_vinput ); + lcl_voutput2 = _mm512_mask_blend_ps( lcl_mlt2, lcl_voutput2, lcl_vinput2 ); + lcl_vmask = _mm512_mask_blend_epi32( lcl_mlt, lcl_vmask, lcl_vnewmask ); + lcl_vmask2 = _mm512_mask_blend_epi32( lcl_mlt2, lcl_vmask2, lcl_vnewmask2 ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + lcl_voutput = _mm512_add_ps( lcl_voutput, _mm512_load_act( input_ptr ) ); + lcl_voutput2 = _mm512_add_ps( lcl_voutput2, _mm512_load_act( input_ptr+16 ) ); +#endif + } + } + } +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 32), lcl_vmask ); + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 16, nBlocksFm, ofh, ofw, 32), lcl_vmask2 ); +#endif + _mm512_storeu_ps( lcl_output_ptr, lcl_voutput ); + _mm512_storeu_ps( lcl_output_ptr+16, lcl_voutput2 ); + } + } + + /* copy the local buffer into output activations */ + for( ho = oph; ho < (ofh+oph); ho++ ) { + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 32); + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, 0, 0, ofw, 32); + for( wo = opw; wo < (ofw+opw); wo++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); + _mm512_stream_act( output_ptr, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr ), recp_pool_size_ps ) ); + _mm512_stream_act( output_ptr+16, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr+16 ), recp_pool_size_ps ) ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_stream_act( output_ptr, _mm512_loadu_ps( lcl_output_ptr ) ); + _mm512_stream_act( output_ptr+16, _mm512_loadu_ps( lcl_output_ptr+16 ) ); +#endif + output_ptr += 32; + lcl_output_ptr += 32; + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..b7f91174af32a6714016f54e0787291adc3d8765 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c @@ -0,0 +1,205 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) +#if 1 +# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) +#else +# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) +#endif +#else +# define _mm512_load_act(A) _mm512_loadu_ps(A) +# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) +# define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) +#endif + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int kh = 0; +int kw = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_output_type recp_pool_size = 1.0f/((element_output_type)handle->desc.R*(element_output_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ofh*(size_t)ofw*(size_t)64*(size_t)ltid); +LIBXSMM_VLA_DECL(3, float, lcl_output, lcl_buffer_ptr, ofw, 64); +#else +element_output_type* lcl_buffer_ptr = ((element_output_type*)handle->scratch)+((size_t)ofh*(size_t)ofw*(size_t)64*(size_t)ltid); +LIBXSMM_VLA_DECL(3, element_output_type, lcl_output, lcl_buffer_ptr, ofw, 64); +#endif +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, 64); +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, 64); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) +LIBXSMM_VLA_DECL(5, element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 64); +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_viadd = _mm512_set_epi32( 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 ); +#endif + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + for( v = 0; v < ofh*ofw*64; v+=16 ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_set1_ps(-FLT_MAX) ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); +#endif + } + + for( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for( wo = opw; wo < (ofw+opw); wo++ ) { + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, 64); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_vmask = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 64) ); + __m512i lcl_vmask2 = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 16, nBlocksFm, ofh, ofw, 64) ); + __m512i lcl_vmask3 = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 32, nBlocksFm, ofh, ofw, 64) ); + __m512i lcl_vmask4 = _mm512_loadu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 48, nBlocksFm, ofh, ofw, 64) ); +#endif + __m512 lcl_voutput = _mm512_loadu_ps( lcl_output_ptr ); + __m512 lcl_voutput2 = _mm512_loadu_ps( lcl_output_ptr+16 ); + __m512 lcl_voutput3 = _mm512_loadu_ps( lcl_output_ptr+32 ); + __m512 lcl_voutput4 = _mm512_loadu_ps( lcl_output_ptr+48 ); + + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi+kh+iph, wi+kw+ipw, 0, nBlocksFm, ifhp, ifwp, 64); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + __m512i lcl_vnewmask = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*16 + (wi+kw)*16) ); + __m512i lcl_vnewmask2 = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*16 + (wi+kw)*16) ); + __m512i lcl_vnewmask3 = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*16 + (wi+kw)*16) ); + __m512i lcl_vnewmask4 = _mm512_add_epi32( lcl_viadd, _mm512_set1_epi32((hi+kh)*ifw*16 + (wi+kw)*16) ); + __m512 lcl_vinput = _mm512_load_act( input_ptr ); + __m512 lcl_vinput2 = _mm512_load_act( input_ptr+16 ); + __m512 lcl_vinput3 = _mm512_load_act( input_ptr+32 ); + __m512 lcl_vinput4 = _mm512_load_act( input_ptr+48 ); + __mmask16 lcl_mlt = _mm512_cmp_ps_mask( lcl_voutput, lcl_vinput, _CMP_LT_OS ); + __mmask16 lcl_mlt2 = _mm512_cmp_ps_mask( lcl_voutput2, lcl_vinput2, _CMP_LT_OS ); + __mmask16 lcl_mlt3 = _mm512_cmp_ps_mask( lcl_voutput3, lcl_vinput3, _CMP_LT_OS ); + __mmask16 lcl_mlt4 = _mm512_cmp_ps_mask( lcl_voutput4, lcl_vinput4, _CMP_LT_OS ); + lcl_voutput = _mm512_mask_blend_ps( lcl_mlt, lcl_voutput, lcl_vinput ); + lcl_voutput2 = _mm512_mask_blend_ps( lcl_mlt2, lcl_voutput2, lcl_vinput2 ); + lcl_voutput3 = _mm512_mask_blend_ps( lcl_mlt3, lcl_voutput3, lcl_vinput3 ); + lcl_voutput4 = _mm512_mask_blend_ps( lcl_mlt4, lcl_voutput4, lcl_vinput4 ); + lcl_vmask = _mm512_mask_blend_epi32( lcl_mlt, lcl_vmask, lcl_vnewmask ); + lcl_vmask2 = _mm512_mask_blend_epi32( lcl_mlt2, lcl_vmask2, lcl_vnewmask2 ); + lcl_vmask3 = _mm512_mask_blend_epi32( lcl_mlt3, lcl_vmask3, lcl_vnewmask3 ); + lcl_vmask4 = _mm512_mask_blend_epi32( lcl_mlt4, lcl_vmask4, lcl_vnewmask4 ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + lcl_voutput = _mm512_add_ps( lcl_voutput, _mm512_load_act( input_ptr ) ); + lcl_voutput2 = _mm512_add_ps( lcl_voutput2, _mm512_load_act( input_ptr+16 ) ); + lcl_voutput3 = _mm512_add_ps( lcl_voutput3, _mm512_load_act( input_ptr+32 ) ); + lcl_voutput4 = _mm512_add_ps( lcl_voutput4, _mm512_load_act( input_ptr+48 ) ); +#endif + } + } + } +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 64), lcl_vmask ); + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 16, nBlocksFm, ofh, ofw, 64), lcl_vmask2 ); + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 32, nBlocksFm, ofh, ofw, 64), lcl_vmask3 ); + _mm512_storeu_si512( &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 48, nBlocksFm, ofh, ofw, 64), lcl_vmask4 ); +#endif + _mm512_storeu_ps( lcl_output_ptr, lcl_voutput ); + _mm512_storeu_ps( lcl_output_ptr+16, lcl_voutput2 ); + _mm512_storeu_ps( lcl_output_ptr+32, lcl_voutput3 ); + _mm512_storeu_ps( lcl_output_ptr+48, lcl_voutput4 ); + } + } + + /* copy the local buffer into output activations */ + for( ho = oph; ho < (ofh+oph); ho++ ) { + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, opw, 0, nBlocksFm, ofhp, ofwp, 64); + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, 0, 0, ofw, 64); + for( wo = opw; wo < (ofw+opw); wo++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); + _mm512_stream_act( output_ptr, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr ), recp_pool_size_ps ) ); + _mm512_stream_act( output_ptr+16, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr+16 ), recp_pool_size_ps ) ); + _mm512_stream_act( output_ptr+32, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr+32 ), recp_pool_size_ps ) ); + _mm512_stream_act( output_ptr+48, _mm512_mul_ps( _mm512_loadu_ps( lcl_output_ptr+48 ), recp_pool_size_ps ) ); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + _mm512_stream_act( output_ptr, _mm512_loadu_ps( lcl_output_ptr ) ); + _mm512_stream_act( output_ptr+16, _mm512_loadu_ps( lcl_output_ptr+16 ) ); + _mm512_stream_act( output_ptr+32, _mm512_loadu_ps( lcl_output_ptr+32 ) ); + _mm512_stream_act( output_ptr+48, _mm512_loadu_ps( lcl_output_ptr+48 ) ); +#endif + output_ptr += 64; + lcl_output_ptr += 64; + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + +# undef _mm512_load_act +# undef _mm512_stream_act +# undef _mm512_store_act + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..0a90220337ba4853b2795dc2fb70531ea6979f13 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c @@ -0,0 +1,194 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) +******************************************************************************/ + +/* size variables, all const */ +const int nImg = handle->desc.N; +const int ifh = handle->desc.H; +const int ifw = handle->desc.W; +const int sh = handle->desc.u; +const int sw = handle->desc.v; +const int ofh = handle->ofh; +const int ofw = handle->ofw; +const int iph = handle->desc.pad_h_in; +const int ipw = handle->desc.pad_w_in; +const int oph = handle->desc.pad_h_out; +const int opw = handle->desc.pad_w_out; +const int ofhp = ofh + 2*oph; +const int ofwp = ofw + 2*opw; +const int ifhp = ifh + 2*iph; +const int ifwp = ifw + 2*ipw; +/* here we assume that input and output blocking is similar */ +const int nBlocksFm = handle->blocksifm; +const int nFmBlock = handle->ifmblock; + +/* computing first logical thread */ +const int ltid = tid - start_thread; +/* number of tasks that could be run in parallel */ +const int work = nImg * nBlocksFm; +/* compute chunk size */ +const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* loop variables */ +int img = 0; +int fm = 0; +int imgfm = 0; +int ho = 0; +int wo = 0; +int hi = 0; +int wi = 0; +int kh = 0; +int kw = 0; +int v = 0; +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); +#else +element_output_type recp_pool_size = 1.0f/((element_output_type)handle->desc.R*(element_output_type)handle->desc.S); +#endif +#endif + +/* multi-dim arrays declaration */ +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +float *const lcl_buffer_ptr = (float*)handle->scratch + (size_t)ofh*ofw*nFmBlock*ltid; +LIBXSMM_VLA_DECL(3, float, lcl_output, lcl_buffer_ptr, ofw, nFmBlock); +#else +element_output_type *const lcl_buffer_ptr = (element_output_type*)handle->scratch + (size_t)ofh*ofw*nFmBlock*ltid; +LIBXSMM_VLA_DECL(3, element_output_type, lcl_output, lcl_buffer_ptr, ofw, nFmBlock); +#endif +LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksFm, ifhp, ifwp, nFmBlock); +LIBXSMM_VLA_DECL(5, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksFm, ofhp, ofwp, nFmBlock); +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) +LIBXSMM_VLA_DECL(5, element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) +union libxsmm_bfloat16_hp input_f32; +union libxsmm_bfloat16_hp output_f32; +input_f32.i[1] = 0; +input_f32.i[0] = 0; +output_f32.i[1] = 0; +output_f32.i[0] = 0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, ltid); + +for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { + img = imgfm / nBlocksFm; + fm = imgfm % nBlocksFm; + + LIBXSMM_PRAGMA_SIMD + for ( v = 0; v < ofh*ofw*nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + lcl_buffer_ptr[v] = -FLT_MAX; +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) + lcl_buffer_ptr[v] = (float)0.0; +#else + lcl_buffer_ptr[v] = (element_output_type)0.0; +#endif +#endif + } + + for ( ho = oph; ho < (ofh+oph); ho++ ) { + hi = ((ho-oph) * sh) - handle->desc.pad_h; + for ( wo = opw; wo < (ofw+opw); wo++ ) { + wi = ((wo-opw) * sw) - handle->desc.pad_w; + for ( kh = 0; kh < handle->desc.R; kh++ ) { + if (hi+kh < 0 || hi+kh >= ifh) continue; + for ( kw = 0; kw < handle->desc.S; kw++ ) { + if (wi+kw < 0 || wi+kw >= ifw) { + continue; + } else { + const element_input_type* input_ptr = &LIBXSMM_VLA_ACCESS(5, input, img, fm, hi+kh+iph, wi+kw+ipw, 0, nBlocksFm, ifhp, ifwp, nFmBlock); +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, nFmBlock); +#else + element_output_type* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + const int idx = (hi+kh)*ifw*nFmBlock + (wi+kw)*nFmBlock; + element_mask_type* mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, nFmBlock); +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) + for ( v = 0; v < nFmBlock; v++ ) { + input_f32.i[1] = input_ptr[v]; +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + if ( input_f32.f > lcl_output_ptr[v] ) { + lcl_output_ptr[v] = input_f32.f; + mask_ptr[v] = idx + v; + } +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + lcl_output_ptr[v] += input_f32.f; +#endif + } +#else + LIBXSMM_PRAGMA_SIMD + for ( v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + if ( input_ptr[v] > lcl_output_ptr[v] ) { + lcl_output_ptr[v] = input_ptr[v]; + mask_ptr[v] = idx + v; + } +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + lcl_output_ptr[v] += input_ptr[v]; +#endif + } +#endif + } + } + } + } + } + + /* copy the local buffer into output activations */ + for ( ho = oph; ho < (ofh+oph); ho++ ) { + for ( wo = opw; wo < (ofw+opw); wo++ ) { + element_output_type* output_ptr = &LIBXSMM_VLA_ACCESS(5, output, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, nFmBlock); +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) + float* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, nFmBlock); +#else + element_output_type* lcl_output_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_output, ho-oph, wo-opw, 0, ofw, nFmBlock); +#endif + +#if defined(LIBXSMM_DNN_POOLING_FWD_BF16) + for ( v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + output_f32.f = lcl_output_ptr[v]; +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + output_f32.f = lcl_output_ptr[v] * recp_pool_size; +#endif + output_ptr[v] = output_f32.i[1]; + } +#else + LIBXSMM_PRAGMA_SIMD + for ( v = 0; v < nFmBlock; v++ ) { +#if defined(LIBXSMM_DNN_POOLING_FWD_MAX) + output_ptr[v] = lcl_output_ptr[v]; +#endif +#if defined(LIBXSMM_DNN_POOLING_FWD_AVG) + output_ptr[v] = lcl_output_ptr[v] * recp_pool_size; +#endif + } +#endif + } + } +} + +libxsmm_barrier_wait(handle->barrier, ltid); + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..94e18fa664a30ea8551f3c07110ed2a34f920171 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c @@ -0,0 +1,637 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +libxsmm_blasint K3 = K * 3; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *ht = (element_output_type*)(handle->ht ? handle->ht->data : NULL); +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ct = (element_output_type*)handle->cit->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_input_type *dxt = (element_input_type* )handle->dxt->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db = (element_output_type*)handle->db->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dcD = (element_output_type*)handle->scratch_dci; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *doutD = (element_output_type*)handle->scratch_deltat; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +element_output_type *scratch_oT = (element_output_type*)handle->scratch_dpB; +element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w; +element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[K]); +element_filter_type *dwfD = &(dw[2*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K]); +element_filter_type *drfD = &(dr[2*K]); +element_filter_type *dwiD_scratch = &(w_scratch[0]); +element_filter_type *dwcD_scratch = &(w_scratch[C*K]); +element_filter_type *dwfD_scratch = &(w_scratch[2*C*K]); +element_filter_type *driD_scratch = &(r_scratch[0]); +element_filter_type *drcD_scratch = &(r_scratch[K*K]); +element_filter_type *drfD_scratch = &(r_scratch[2*K*K]); +element_output_type *dbi = &(db[0]); +element_output_type *dbc = &(db[K]); +element_output_type *dbf = &(db[2*K]); +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_output_type *t1D = (element_output_type*)handle->scratch_t1; +element_output_type *t2D = (element_output_type*)handle->scratch_t2; +/* multidimensional arrays */ +LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K); +LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, wi, wiD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, wc, wcD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, wf, wfD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, ri, riD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, rc, rcD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, rf, rfD, K3); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, c, ct, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_filter_type, dwi_ck, dwiD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, dwc_ck, dwcD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, dwf_ck, dwfD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, dri_ck, driD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, drc_ck, drcD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, drf_ck, drfD, K3); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dc, dcD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(4, element_filter_type, wiT, scratch_wiT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wcT, scratch_wcT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wfT, scratch_wfT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, riT, scratch_riT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rcT, scratch_rcT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rfT, scratch_rfT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +LIBXSMM_VLA_DECL(2, element_output_type, oT, scratch_oT, N); +element_output_type *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL ); +#if 0 +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL ); +#endif +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); + +/* Auxiliary arrays for batch-reduce gemm calls */ +const element_filter_type *A_array[1024]; +const element_output_type *B_array[1024]; + +#if 0 +LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dcB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk); +#endif + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +/* number of tasks that could be run in parallel for K blocks*/ +/* compute chunk size */ +const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K; +const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K; + +/* int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; */ + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*3, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*3, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*3, db, start_thread, tid, handle->desc.threads); +} + +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wi, ic*bc+jc, ik*bk+jk, K3); + LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wc, ic*bc+jc, ik*bk+jk, K3); + LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wf, ic*bc+jc, ik*bk+jk, K3); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ri, ic*bk+jc, ik*bk+jk, K3); + LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rc, ic*bk+jc, ik*bk+jk, K3); + LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rf, ic*bk+jc, ik*bk+jk, K3); + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = t-1; j >= 0; --j) { + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; + + /* compute dhp */ + if (j == t-1) { + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } + /* df = dout . (1 - c) . (1 - (f . f)) */ + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + /* dc = dout . (hp - f) . c . (1 - c) */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (0 == j) { + libxsmm_internal_matrix_sub_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } else { + LIBXSMM_ASSERT(NULL != ht); /* coverity[var_deref_op] */ + libxsmm_internal_matrix_sub_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K) ); + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + in = (icin / (C/bc))*bn; + ic = (icin % (C/bc))*bc; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, j, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (j == 0) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, j-1, en, ek, N, K); + } + } + } + } + + /* transpose ot for current timestep */ + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, oT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, o, j, en, ek, N, K); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + /* do = {R_f}^T * df */ + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + + if (KB == 0) libxsmm_internal_matrix_zero_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K) ); + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rfT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, df, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kerneld(A_array, B_array, &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &blocks); + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + /* di = do . hp . i . (1 - i) */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (0 == j) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dx = W^T * dicf */ + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic % (N/bn))*bn; + icb = inic / (N/bn); + ic = icb*bc; + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wiT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wcT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wfT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + } + } + } + + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + dout_ptr = (j > 0) ? (element_output_type*) &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) : (element_output_type*) &LIBXSMM_VLA_ACCESS(2, dhp, in, ik, K); + + if (0 == KB) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), dout_ptr ); + } + + /* dhp += R^T * dic */ + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, riT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, di, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rcT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + } + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + if ((C == K) && (bc == bk) /*&& (bcbk_multiples_of_16 == 1)*/) { +#if 0 + if (K % 2048 != 0) { +#endif + /* Interleave computation of dr = dicf * o^T/h^T and dw = dicf * x^T to take advantage of temporal locality */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + blocks = nBlocks; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } +#if 0 + } else { + /* Interleave computation of dr = dicf * o^T/h^T and dw = dicf * x^T to take advantage of temporal locality */ + /* Use blocked format for di, dc, df */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + blocks = nBlocks; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dcB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dcB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } +#endif + } else { + /* dr = dicf * o^T/h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = dicf * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } + + /* gradient bias */ + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + dbi[ik] += LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + dbc[ik] += LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + dbf[ik] += LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* Store result weight matrices in CK format */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; ++jc) { + for (jk = 0; jk < bk; ++jk) { + LIBXSMM_VLA_ACCESS(2, dwi_ck, ic+jc, ik+jk, K3) = LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(2, dwc_ck, ic+jc, ik+jk, K3) = LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(2, dwf_ck, ic+jc, ik+jk, K3) = LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk); + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; ++jc) { + for (jk = 0; jk < bk; ++jk) { + LIBXSMM_VLA_ACCESS(2, dri_ck, ic+jc, ik+jk, K3) = LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(2, drc_ck, ic+jc, ik+jk, K3) = LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(2, drf_ck, ic+jc, ik+jk, K3) = LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk); + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..834be810372a2e4528522c6bb93360b827d1c99b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c @@ -0,0 +1,626 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ct = (element_output_type*)handle->cit->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_input_type *dxt = (element_input_type* )handle->dxt->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db = (element_output_type*)handle->db->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dcD = (element_output_type*)handle->scratch_dci; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *doutD = (element_output_type*)handle->scratch_deltat; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +element_output_type *scratch_oT = (element_output_type*)handle->scratch_dpB; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[C*K]); +element_filter_type *wfD = &(w[2*C*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K*K]); +element_filter_type *rfD = &(r[2*K*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[C*K]); +element_filter_type *dwfD = &(dw[2*C*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K*K]); +element_filter_type *drfD = &(dr[2*K*K]); +element_output_type *dbi = &(db[0]); +element_output_type *dbc = &(db[K]); +element_output_type *dbf = &(db[2*K]); +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_output_type *t1D = (element_output_type*)handle->scratch_t1; +element_output_type *t2D = (element_output_type*)handle->scratch_t2; +/* multidimensional arrays */ +LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K); +LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, c, ct, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dc, dcD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(4, element_filter_type, wiT, scratch_wiT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wcT, scratch_wcT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wfT, scratch_wfT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, riT, scratch_riT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rcT, scratch_rcT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rfT, scratch_rfT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +LIBXSMM_VLA_DECL(2, element_output_type, oT, scratch_oT, N); +element_output_type *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL ); +#if 0 +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL ); +#endif +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); + +/* Auxiliary arrays for batch-reduce gemm calls */ +const element_filter_type *A_array[1024]; +const element_output_type *B_array[1024]; + +#if 0 +LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dcB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk); +#endif + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +/* number of tasks that could be run in parallel for K blocks*/ +/* compute chunk size */ +const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K; +const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K; + +libxsmm_blasint ikic, inic, inik, icin, ikin; +#if defined(LIBXSMM_RNN_CELL_AVX512) +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K >= 1024 && K%2==0) { + BF = 2; +} +if (K >= 2048 && K%4==0) { + BF = 4; +} +if (K >= 4096 && K%8==0) { + BF = 8; +} +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*3, dw, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*3, dr, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*3, db, start_thread, tid, handle->desc.threads); +} + +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk); + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = t-1; j >= 0; --j) { + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bcbk_multiples_of_16) { +#include "libxsmm_internal_gru_bwdupd_fused_eltwise_1.tpl.c" + } else { + /* compute dhp */ + if (j == t-1) { + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } + /* df = dout . (1 - c) . (1 - (f . f)) */ + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + /* dc = dout . (hp - f) . c . (1 - c) */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (0 == j) { + libxsmm_internal_matrix_sub_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } else { + libxsmm_internal_matrix_sub_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K) ); + } +#else + /* compute dhp */ + if (j == t-1) { + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } + /* df = dout . (1 - c) . (1 - (f . f)) */ + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + /* dc = dout . (hp - f) . c . (1 - c) */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (0 == j) { + libxsmm_internal_matrix_sub_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } else { + libxsmm_internal_matrix_sub_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K) ); +#endif + } + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + in = (icin / (C/bc))*bn; + ic = (icin % (C/bc))*bc; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, j, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (j == 0) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, j-1, en, ek, N, K); + } + } + } + } + + /* transpose ot for current timestep */ + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, oT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, o, j, en, ek, N, K); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + /* do = {R_f}^T * df */ + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + + if (KB == 0) libxsmm_internal_matrix_zero_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K) ); + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rfT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, df, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kerneld(A_array, B_array, &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &blocks); + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + /* di = do . hp . i . (1 - i) */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bcbk_multiples_of_16) { +#include "libxsmm_internal_gru_bwdupd_fused_eltwise_2.tpl.c" + } else { + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (0 == j) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); + } +#else + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (0 == j) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + } + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); +#endif + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dx = W^T * dicf */ + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic % (N/bn))*bn; + icb = inic / (N/bn); + ic = icb*bc; + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wiT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wcT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wfT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + } + } + } + + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + dout_ptr = (j > 0) ? (element_output_type*) &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) : (element_output_type*) &LIBXSMM_VLA_ACCESS(2, dhp, in, ik, K); + + if (0 == KB) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), dout_ptr ); + } + + /* dhp += R^T * dic */ + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, riT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, di, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rcT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + } + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + if ((C == K) && (bc == bk) /*&& (bcbk_multiples_of_16 == 1)*/) { +#if 0 + if (K % 2048 != 0) { +#endif + /* Interleave computation of dr = dicf * o^T/h^T and dw = dicf * x^T to take advantage of temporal locality */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + blocks = nBlocks; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } +#if 0 + } else { + /* Interleave computation of dr = dicf * o^T/h^T and dw = dicf * x^T to take advantage of temporal locality */ + /* Use blocked format for di, dc, df */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + blocks = nBlocks; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dcB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dcB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } +#endif + } else { + /* dr = dicf * o^T/h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, oT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = dicf * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } + + /* gradient bias */ + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + dbi[ik] += LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + dbc[ik] += LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + dbf[ik] += LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..dfe775ad76eeaba02b36356fc918700c3b50478b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c @@ -0,0 +1,285 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS, ikic, jk, jc; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint K3 = K * 3; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w; +element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ct = (element_output_type*)handle->cit->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *wiD_scratch = &(w_scratch[0]); +element_filter_type *wcD_scratch = &(w_scratch[C*K]); +element_filter_type *wfD_scratch = &(w_scratch[2*C*K]); +element_filter_type *riD_scratch = &(r_scratch[0]); +element_filter_type *rcD_scratch = &(r_scratch[K*K]); +element_filter_type *rfD_scratch = &(r_scratch[2*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_filter_type, wi_ck, wiD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, wc_ck, wcD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, wf_ck, wfD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, ri_ck, riD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, rc_ck, rcD, K3); +LIBXSMM_VLA_DECL(2, element_filter_type, rf_ck, rfD, K3); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, c, ct, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); +/* define gemm kernels */ +/* Auxiliary arrays for batch-reduce gemms */ +const element_filter_type *A_array[1024]; +const element_input_type *B_array[1024]; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; +#if 0 +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; +#endif +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +/* Upfront reformatting of W and R */ +/* reformat W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wi_ck, ic*bc+jc, ik*bk+jk, 3*K); + LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wc_ck, ic*bc+jc, ik*bk+jk, 3*K); + LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wf_ck, ic*bc+jc, ik*bk+jk, 3*K); + } + } +} + +/* reformat R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ri_ck, ic*bk+jc, ik*bk+jk, 3*K); + LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rc_ck, ic*bk+jc, ik*bk+jk, 3*K); + LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rf_ck, ic*bk+jc, ik*bk+jk, 3*K); + } + } +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* All data is in column-major format */ +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); + /* i += R.hp */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); + /* initialize c with bd */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &bd[ik] ); + /* c += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks); + /* c += R.hp */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks); + + if (CB == BF-1) { + /* i = sigmoid(i) */ + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + /* o = hp . i */ + if (0 == j) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + /* We need a barrier here to ensure all elements of o are computed before f can be computed */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize f with bf */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik] ); + /* f += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); + /* f += R.o */ + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, o, j, in, ic + CB*KB_BLOCKS*bk, N, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); + + if (CB == BF-1) { + /* f = tanh(f) */ + libxsmm_internal_matrix_tanh_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + /* c = sigmoid(c) */ + libxsmm_internal_matrix_sigmoid_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K) ); + /* h = (1 - c) . f */ + libxsmm_internal_matrix_complement_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + /* h += c . hp */ + if (0 == j) { + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } else { + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..92d429bd823e8d30f9544906eb8d96302be9056a --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c @@ -0,0 +1,222 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ct = (element_output_type*)handle->cit->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[C*K]); +element_filter_type *wfD = &(w[2*C*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K*K]); +element_filter_type *rfD = &(r[2*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, c, ct, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); +/* define gemm kernels */ +/* Auxiliary arrays for batch-reduce gemms */ +const element_filter_type *A_array[1024]; +const element_input_type *B_array[1024]; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +#if 0 +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; +#endif +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* All data is in column-major format */ +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); + /* i += R.hp */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); + /* initialize c with bd */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &bd[ik] ); + /* c += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks); + /* c += R.hp */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks); + + if (CB == BF-1) { + /* i = sigmoid(i) */ + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + /* o = hp . i */ + if (0 == j) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + /* We need a barrier here to ensure all elements of o are computed before f can be computed */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize f with bf */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik] ); + /* f += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); + /* f += R.o */ + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, o, j, in, ic + CB*KB_BLOCKS*bk, N, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); + + if (CB == BF-1) { + /* f = tanh(f) */ + libxsmm_internal_matrix_tanh_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + /* c = sigmoid(c) */ + libxsmm_internal_matrix_sigmoid_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K) ); + /* h = (1 - c) . f */ + libxsmm_internal_matrix_complement_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + /* h += c . hp */ + if (0 == j) { + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } else { + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..d4574e937fd2ff4012025658dafb2c6413f1fc26 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c @@ -0,0 +1,360 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +libxsmm_blasint K4 = K * 4; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +element_output_type *doutD = (element_output_type*)handle->scratch_deltat; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w; +element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *woD = &(w[3*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *roD = &(r[3*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[K]); +element_filter_type *dwfD = &(dw[2*K]); +element_filter_type *dwoD = &(dw[3*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K]); +element_filter_type *drfD = &(dr[2*K]); +element_filter_type *droD = &(dr[3*K]); +element_filter_type *dwiD_scratch = &(w_scratch[0]); +element_filter_type *dwcD_scratch = &(w_scratch[C*K]); +element_filter_type *dwfD_scratch = &(w_scratch[2*C*K]); +element_filter_type *dwoD_scratch = &(w_scratch[3*C*K]); +element_filter_type *driD_scratch = &(r_scratch[0]); +element_filter_type *drcD_scratch = &(r_scratch[K*K]); +element_filter_type *drfD_scratch = &(r_scratch[2*K*K]); +element_filter_type *droD_scratch = &(r_scratch[3*K*K]); +element_output_type *dbi = &(db[0]); +element_output_type *dbc = &(db[K]); +element_output_type *dbf = &(db[2*K]); +element_output_type *dbo = &(db[3*K]); +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +element_output_type *t1D = (element_output_type*)handle->scratch_t1; +element_output_type *t2D = (element_output_type*)handle->scratch_t2; +/* multidimensional arrays */ +LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K); +LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, wi, wiD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wf, wfD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wo, woD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wc, wcD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, ri, riD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, rf, rfD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, ro, roD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, rc, rcD, K4); +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwo, dwoD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dro, droD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_filter_type, dwi_ck, dwiD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwf_ck, dwfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwo_ck, dwoD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwc_ck, dwcD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dri_ck, driD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, drf_ck, drfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dro_ck, droD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, drc_ck, drcD, 4*K); +LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(4, element_filter_type, wiT, scratch_wiT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wcT, scratch_wcT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wfT, scratch_wfT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, woT, scratch_woT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, riT, scratch_riT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rcT, scratch_rcT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rfT, scratch_rfT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, roT, scratch_roT, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +element_output_type *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL); + +/* Auxiliary arrays for batch-reduce gemm calls */ +const element_filter_type *A_array[1024]; +const element_output_type *B_array[1024]; + +LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk); + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K;__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#endif +/* number of tasks that could be run in parallel for K blocks*/ +/* compute chunk size */ +const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K; +const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K; +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wi, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wc, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wf, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, woT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wo, ic*bc+jc, ik*bk+jk, 4*K); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ri, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rc, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rf, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, roT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ro, ic*bk+jc, ik*bk+jk, 4*K); + } + } +} +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c" + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* Store result weight matrices in CK format */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; ++jc) { + for (jk = 0; jk < bk; ++jk) { + LIBXSMM_VLA_ACCESS(2, dwi_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(2, dwc_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(2, dwf_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(2, dwo_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk); + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; ++jc) { + for (jk = 0; jk < bk; ++jk) { + LIBXSMM_VLA_ACCESS(2, dri_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(2, drc_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(2, drf_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(2, dro_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk); + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + reformat_cycles += _end - _start; + } +#endif +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..fb1def3838ad615b4dc661748018376e913e45a7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c @@ -0,0 +1,361 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +libxsmm_blasint K4 = K * 4; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +const int lpb = handle->lpb; +/*const int bc_lp = bc/lpb;*/ +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db_bf16 = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +float *dxD = (float*)handle->scratch_dx; +float *doutD = (float*)handle->scratch_deltat; +float *dhpD_f32 = (float*)handle->scratch_dhp; +float *db = (float*)handle->scratch_db; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +float *w_scratch = (float*)handle->scratch_w; +float *r_scratch = (float*)handle->scratch_r; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *woD = &(w[3*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *roD = &(r[3*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[K]); +element_filter_type *dwfD = &(dw[2*K]); +element_filter_type *dwoD = &(dw[3*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K]); +element_filter_type *drfD = &(dr[2*K]); +element_filter_type *droD = &(dr[3*K]); +float *dwiD_scratch = &(w_scratch[0]); +float *dwcD_scratch = &(w_scratch[C*K]); +float *dwfD_scratch = &(w_scratch[2*C*K]); +float *dwoD_scratch = &(w_scratch[3*C*K]); +float *driD_scratch = &(r_scratch[0]); +float *drcD_scratch = &(r_scratch[K*K]); +float *drfD_scratch = &(r_scratch[2*K*K]); +float *droD_scratch = &(r_scratch[3*K*K]); +float *dbi = &(db[0]); +float *dbc = &(db[K]); +float *dbf = &(db[2*K]); +float *dbo = &(db[3*K]); +element_output_type *dbi_bf16 = &(db_bf16[0]); +element_output_type *dbc_bf16 = &(db_bf16[K]); +element_output_type *dbf_bf16 = &(db_bf16[2*K]); +element_output_type *dbo_bf16 = &(db_bf16[3*K]); +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +/*element_output_type *t1D = (element_output_type*)handle->scratch_t1;*/ +/*element_output_type *t2D = (element_output_type*)handle->scratch_t2;*/ +/* multidimensional arrays */ +/*LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);*/ +/*LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);*/ +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, wi, wiD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wf, wfD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wo, woD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wc, wcD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, ri, riD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, rf, rfD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, ro, roD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, rc, rcD, K4); +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +LIBXSMM_VLA_DECL(3, float, dx, dxD, N, C); +LIBXSMM_VLA_DECL(3, element_input_type, dx_bf16, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(2, float, dhp_f32, dhpD_f32, K); +LIBXSMM_VLA_DECL(4, float, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwo, dwoD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, dro, droD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_filter_type, dwi_ck, dwiD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwf_ck, dwfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwo_ck, dwoD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwc_ck, dwcD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dri_ck, driD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, drf_ck, drfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dro_ck, droD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, drc_ck, drcD, 4*K); +LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K); +LIBXSMM_VLA_DECL(5, element_output_type, diB, (element_output_type*)handle->scratch_diB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, float, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(5, element_filter_type, wiT, scratch_wiT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wcT, scratch_wcT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wfT, scratch_wfT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, woT, scratch_woT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, riT, scratch_riT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rcT, scratch_rcT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rfT, scratch_rfT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, roT, scratch_roT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +float *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->bwdupd_kernela; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->bwdupd_kernelb; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelc = handle->bwdupd_kernelc; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kerneld = handle->bwdupd_kerneld; +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K; +__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxD, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(5, wiT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wi, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wc, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wf, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, woT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wo, ic*bc+jc, ik*bk+jk, 4*K); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, riT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ri, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rc, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rf, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, roT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ro, ic*bk+jc, ik*bk+jk, 4*K); + } + } +} +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16.tpl.c" + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* Store result weight matrices in CK format and downcovert to bf16 */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; ++jc) { + for (jk = 0; jk < bk; jk += 16) { + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dwi_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk)))); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dwc_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk)))); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dwf_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk)))); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dwo_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk)))); + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; ++jc) { + for (jk = 0; jk < bk; jk += 16) { + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dri_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk)))); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, drc_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk)))); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, drf_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk)))); + _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dro_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk)))); + } + } + } +#else + /* TODO: Add here non AVX512 replacement code */ +#endif + libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + reformat_cycles += _end - _start; + } +#endif +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..14b465ab65a454d7145738328bee7c6dd1011bc6 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c @@ -0,0 +1,376 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +libxsmm_blasint K4 = K * 4; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +const int lpb = handle->lpb; +/*const int bc_lp = bc/lpb;*/ +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db_bf16 = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +float *dxD = (float*)handle->scratch_dx; +float *doutD = (float*)handle->scratch_deltat; +float *dhpD_f32 = (float*)handle->scratch_dhp; +float *db = (float*)handle->scratch_db; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +float *w_scratch = (float*)handle->scratch_w; +float *r_scratch = (float*)handle->scratch_r; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *woD = &(w[3*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *roD = &(r[3*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[K]); +element_filter_type *dwfD = &(dw[2*K]); +element_filter_type *dwoD = &(dw[3*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K]); +element_filter_type *drfD = &(dr[2*K]); +element_filter_type *droD = &(dr[3*K]); +float *dwiD_scratch = &(w_scratch[0]); +float *dwcD_scratch = &(w_scratch[C*K]); +float *dwfD_scratch = &(w_scratch[2*C*K]); +float *dwoD_scratch = &(w_scratch[3*C*K]); +float *driD_scratch = &(r_scratch[0]); +float *drcD_scratch = &(r_scratch[K*K]); +float *drfD_scratch = &(r_scratch[2*K*K]); +float *droD_scratch = &(r_scratch[3*K*K]); +float *dbi = &(db[0]); +float *dbc = &(db[K]); +float *dbf = &(db[2*K]); +float *dbo = &(db[3*K]); +element_output_type *dbi_bf16 = &(db_bf16[0]); +element_output_type *dbc_bf16 = &(db_bf16[K]); +element_output_type *dbf_bf16 = &(db_bf16[2*K]); +element_output_type *dbo_bf16 = &(db_bf16[3*K]); +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +/*element_output_type *t1D = (element_output_type*)handle->scratch_t1;*/ +/*element_output_type *t2D = (element_output_type*)handle->scratch_t2;*/ +/* multidimensional arrays */ +/*LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);*/ +/*LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);*/ +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, wi, wiD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wf, wfD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wo, woD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, wc, wcD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, ri, riD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, rf, rfD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, ro, roD, K4); +LIBXSMM_VLA_DECL(2, element_filter_type, rc, rcD, K4); +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +LIBXSMM_VLA_DECL(3, float, dx, dxD, N, C); +LIBXSMM_VLA_DECL(3, element_input_type, dx_bf16, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(2, float, dhp_f32, dhpD_f32, K); +LIBXSMM_VLA_DECL(4, float, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwo, dwoD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, dro, droD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_filter_type, dwi_ck, dwiD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwf_ck, dwfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwo_ck, dwoD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dwc_ck, dwcD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dri_ck, driD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, drf_ck, drfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, dro_ck, droD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, drc_ck, drcD, 4*K); +LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K); +LIBXSMM_VLA_DECL(5, element_output_type, diB, (element_output_type*)handle->scratch_diB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, float, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(5, element_filter_type, wiT, scratch_wiT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wcT, scratch_wcT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wfT, scratch_wfT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, woT, scratch_woT, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, riT, scratch_riT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rcT, scratch_rcT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rfT, scratch_rfT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, roT, scratch_roT, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +float *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->bwdupd_kernela; /*libxsmm_bsmmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->bwdupd_kernelb; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelc = handle->bwdupd_kernelc; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kerneld = handle->bwdupd_kerneld; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);*/ +libxsmm_bsmmfunction_reducebatch_addr tile_config_kernel = handle->bwdupd_tileconfig; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);*/ + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K; +__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#endif +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* Hoist tileconfig if possible */ +if ((bk % 32 == 0) && (bc % 32 == 0) && (bn % 32 == 0)) { + tile_config_kernel(NULL, NULL, NULL, NULL); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} + +BF = handle->bwdupd_block; +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxD, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(5, wiT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wi, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wc, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wf, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, woT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(2, wo, ic*bc+jc, ik*bk+jk, 4*K); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, riT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ri, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rc, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rf, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, roT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ro, ic*bk+jc, ik*bk+jk, 4*K); + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16_amx.tpl.c" + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* Store result weight matrices in CK format and downcovert to bf16 */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; ++jc) { + for (jk = 0; jk < bk; jk += 16) { + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, dwi_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, dwc_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, dwf_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, dwo_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk))); + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; ++jc) { + for (jk = 0; jk < bk; jk += 16) { + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, dri_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, drc_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, drf_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_storecvt_fp32_bf16(&LIBXSMM_VLA_ACCESS(2, dro_ck, ic+jc, ik+jk , K4), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk))); + } + } + } +#else + /* TODO: Add here non AVX512 replacement code */ +#endif + libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + reformat_cycles += _end - _start; + } +#endif +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..272a22b3925a975d5aee65f85502603169986714 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c @@ -0,0 +1,306 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *wt = (element_filter_type*)handle->wt->data; +element_filter_type *rt = (element_filter_type*)handle->rt->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +element_output_type *doutD = (element_output_type*)handle->scratch_deltat; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +#if 0 +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +#endif +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +element_filter_type *witD = &(wt[0]); +element_filter_type *wctD = &(wt[C*K]); +element_filter_type *wftD = &(wt[2*C*K]); +element_filter_type *wotD = &(wt[3*C*K]); +element_filter_type *ritD = &(rt[0]); +element_filter_type *rctD = &(rt[K*K]); +element_filter_type *rftD = &(rt[2*K*K]); +element_filter_type *rotD = &(rt[3*K*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[C*K]); +element_filter_type *dwfD = &(dw[2*C*K]); +element_filter_type *dwoD = &(dw[3*C*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K*K]); +element_filter_type *drfD = &(dr[2*K*K]); +element_filter_type *droD = &(dr[3*K*K]); +element_output_type *dbi = &(db[0]); +element_output_type *dbc = &(db[K]); +element_output_type *dbf = &(db[2*K]); +element_output_type *dbo = &(db[3*K]); +#if 0 +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +#endif +element_output_type *t1D = (element_output_type*)handle->scratch_t1; +element_output_type *t2D = (element_output_type*)handle->scratch_t2; +/* multidimensional arrays */ +LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K); +LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +#if 0 +LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wo, woD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ro, roD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk); +#endif +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwo, dwoD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dro, droD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(4, element_filter_type, wiT, witD, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wcT, wctD, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, wfT, wftD, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, woT, wotD, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, riT, ritD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rcT, rctD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rfT, rftD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, roT, rotD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +element_output_type *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL); + +/* Auxiliary arrays for batch-reduce gemm calls */ +const element_filter_type *A_array[1024]; +const element_output_type *B_array[1024]; + +LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk); + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K;__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#endif +/* number of tasks that could be run in parallel for K blocks*/ +/* compute chunk size */ +const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K; +const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K; +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, dw, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, dr, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +/* Here we assume that the weight tensors come in transposed from framework */ +#if 0 +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(4, woT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, wo, ik, ic, jc, jk, cBlocks, bc, bk); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(4, roT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, ro, ik, ic, jc, jk, kBlocks, bk, bk); + } + } +} +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c" + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..1f43f01aa5d210b0eacfc8eef76c0faf6db3e0a0 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c @@ -0,0 +1,447 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +const int lpb = handle->lpb; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *wt = (element_filter_type*)handle->wt->data; +element_filter_type *rt = (element_filter_type*)handle->rt->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db_bf16 = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +float *dxD = (float*)handle->scratch_dx; +float *doutD = (float*)handle->scratch_deltat; +float *dhpD_f32 = (float*)handle->scratch_dhp; +float *db = (float*)handle->scratch_db; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +#if 0 +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +#endif +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +float *w_scratch = (float*)handle->scratch_w; +float *r_scratch = (float*)handle->scratch_r; +element_filter_type *witD = &(wt[0]); +element_filter_type *wctD = &(wt[C*K]); +element_filter_type *wftD = &(wt[2*C*K]); +element_filter_type *wotD = &(wt[3*C*K]); +element_filter_type *ritD = &(rt[0]); +element_filter_type *rctD = &(rt[K*K]); +element_filter_type *rftD = &(rt[2*K*K]); +element_filter_type *rotD = &(rt[3*K*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[C*K]); +element_filter_type *dwfD = &(dw[2*C*K]); +element_filter_type *dwoD = &(dw[3*C*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K*K]); +element_filter_type *drfD = &(dr[2*K*K]); +element_filter_type *droD = &(dr[3*K*K]); +float *dwiD_scratch = &(w_scratch[0]); +float *dwcD_scratch = &(w_scratch[C*K]); +float *dwfD_scratch = &(w_scratch[2*C*K]); +float *dwoD_scratch = &(w_scratch[3*C*K]); +float *driD_scratch = &(r_scratch[0]); +float *drcD_scratch = &(r_scratch[K*K]); +float *drfD_scratch = &(r_scratch[2*K*K]); +float *droD_scratch = &(r_scratch[3*K*K]); +float *dbi = &(db[0]); +float *dbc = &(db[K]); +float *dbf = &(db[2*K]); +float *dbo = &(db[3*K]); +element_output_type *dbi_bf16 = &(db_bf16[0]); +element_output_type *dbc_bf16 = &(db_bf16[K]); +element_output_type *dbf_bf16 = &(db_bf16[2*K]); +element_output_type *dbo_bf16 = &(db_bf16[3*K]); +#if 0 +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +#endif +/*element_output_type *t1D = (element_output_type*)handle->scratch_t1;*/ +/*element_output_type *t2D = (element_output_type*)handle->scratch_t2;*/ +/* multidimensional arrays */ +/*LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);*/ +/*LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);*/ +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +#if 0 +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); +#endif +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +LIBXSMM_VLA_DECL(3, float, dx, dxD, N, C); +LIBXSMM_VLA_DECL(3, element_input_type, dx_bf16, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(2, float, dhp_f32, dhpD_f32, K); +LIBXSMM_VLA_DECL(4, float, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwo, dwoD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, dro, droD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(5, element_filter_type, dwi_bf16, dwiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwc_bf16, dwcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwf_bf16, dwfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwo_bf16, dwoD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dri_bf16, driD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, drc_bf16, drcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, drf_bf16, drfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dro_bf16, droD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K); +LIBXSMM_VLA_DECL(5, element_output_type, diB, (element_output_type*)handle->scratch_diB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, float, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(5, element_filter_type, wiT, witD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wcT, wctD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wfT, wftD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, woT, wotD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, riT, ritD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rcT, rctD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rfT, rftD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, roT, rotD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +float *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->bwdupd_kernela; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->bwdupd_kernelb; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelc = handle->bwdupd_kernelc; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kerneld = handle->bwdupd_kerneld; + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K; +__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#endif +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} + +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxD, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +/* Here we assume that the weight tensors come in transposed from framework */ +#if 0 +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(5, wiT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wi, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, wcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wc, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, wfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wf, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, woT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wo, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, riT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, ri, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, rcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, rc, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, rfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, rf, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, roT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, ro, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + } + } +} +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16.tpl.c" + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* Store result weight matrices in KCCK bf16 format and downcovert to bf16 */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + /* Below is the commented reference code */ +#if 0 + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; jc++) { + for (jk = 0; jk < bk; jk++) { + libxsmm_bfloat16_hp tmp; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwi_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwc_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwf_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwo_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; jc++) { + for (jk = 0; jk < bk; jk++) { + libxsmm_bfloat16_hp tmp; + tmp.f = LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, dri_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, drc_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, drf_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, dro_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + } + } + } +#endif + __m512 a01, b01; + __m512i c01; + const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dwi_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dwc_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dwf_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dwo_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dri_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, drc_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, drf_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + a01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc+1, jk, cBlocks, bc, bk)); + b01 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, cBlocks, bc, bk)); + c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(a01, b01); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(5, dro_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } +#else + /* TODO: Add here non AVX512 replacement code */ + LIBXSMM_UNUSED(thr_begin_kk); + LIBXSMM_UNUSED(thr_begin_ck); + LIBXSMM_UNUSED(ikic); + LIBXSMM_UNUSED(jk); + LIBXSMM_UNUSED(jc); + LIBXSMM_UNUSED(thr_end_ck); + LIBXSMM_UNUSED(thr_end_kk); +#endif + libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + reformat_cycles += _end - _start; + } +#endif +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..1eada7355703dc24030f7650594dddb205b04db7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c @@ -0,0 +1,441 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +const int lpb = handle->lpb; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *wt = (element_filter_type*)handle->wt->data; +element_filter_type *rt = (element_filter_type*)handle->rt->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db_bf16 = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +float *dxD = (float*)handle->scratch_dx; +float *doutD = (float*)handle->scratch_deltat; +float *dhpD_f32 = (float*)handle->scratch_dhp; +float *db = (float*)handle->scratch_db; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +#if 0 +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +#endif +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +float *w_scratch = (float*)handle->scratch_w; +float *r_scratch = (float*)handle->scratch_r; +element_filter_type *witD = &(wt[0]); +element_filter_type *wctD = &(wt[C*K]); +element_filter_type *wftD = &(wt[2*C*K]); +element_filter_type *wotD = &(wt[3*C*K]); +element_filter_type *ritD = &(rt[0]); +element_filter_type *rctD = &(rt[K*K]); +element_filter_type *rftD = &(rt[2*K*K]); +element_filter_type *rotD = &(rt[3*K*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[C*K]); +element_filter_type *dwfD = &(dw[2*C*K]); +element_filter_type *dwoD = &(dw[3*C*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K*K]); +element_filter_type *drfD = &(dr[2*K*K]); +element_filter_type *droD = &(dr[3*K*K]); +float *dwiD_scratch = &(w_scratch[0]); +float *dwcD_scratch = &(w_scratch[C*K]); +float *dwfD_scratch = &(w_scratch[2*C*K]); +float *dwoD_scratch = &(w_scratch[3*C*K]); +float *driD_scratch = &(r_scratch[0]); +float *drcD_scratch = &(r_scratch[K*K]); +float *drfD_scratch = &(r_scratch[2*K*K]); +float *droD_scratch = &(r_scratch[3*K*K]); +float *dbi = &(db[0]); +float *dbc = &(db[K]); +float *dbf = &(db[2*K]); +float *dbo = &(db[3*K]); +element_output_type *dbi_bf16 = &(db_bf16[0]); +element_output_type *dbc_bf16 = &(db_bf16[K]); +element_output_type *dbf_bf16 = &(db_bf16[2*K]); +element_output_type *dbo_bf16 = &(db_bf16[3*K]); +#if 0 +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +#endif +/*element_output_type *t1D = (element_output_type*)handle->scratch_t1;*/ +/*element_output_type *t2D = (element_output_type*)handle->scratch_t2;*/ +/* multidimensional arrays */ +/*LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);*/ +/*LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);*/ +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +#if 0 +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); +#endif +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +LIBXSMM_VLA_DECL(3, float, dx, dxD, N, C); +LIBXSMM_VLA_DECL(3, element_input_type, dx_bf16, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K); +LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K); +LIBXSMM_VLA_DECL(2, float, dhp_f32, dhpD_f32, K); +LIBXSMM_VLA_DECL(4, float, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwo, dwoD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, dro, droD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(5, element_filter_type, dwi_bf16, dwiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwc_bf16, dwcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwf_bf16, dwfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwo_bf16, dwoD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dri_bf16, driD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, drc_bf16, drcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, drf_bf16, drfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dro_bf16, droD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K); +LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); +LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K); +LIBXSMM_VLA_DECL(5, element_output_type, diB, (element_output_type*)handle->scratch_diB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, float, dout, doutD, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(5, element_filter_type, wiT, witD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wcT, wctD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wfT, wftD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, woT, wotD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, riT, ritD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rcT, rctD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rfT, rftD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, roT, rotD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +float *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->bwdupd_kernela; /*libxsmm_bsmmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->bwdupd_kernelb; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelc = handle->bwdupd_kernelc; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kerneld = handle->bwdupd_kerneld; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);*/ +libxsmm_bsmmfunction_reducebatch_addr tile_config_kernel = handle->bwdupd_tileconfig; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);*/ + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K; +__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#endif +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* Hoist tileconfig if possible */ +if ((bk % 32 == 0) && (bc % 32 == 0) && (bn % 32 == 0)) { + tile_config_kernel(NULL, NULL, NULL, NULL); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} + +BF = handle->bwdupd_block; +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxD, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +/* Here we assume that the weight tensors come in transposed from framework */ +#if 0 +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(5, wiT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wi, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, wcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wc, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, wfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wf, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, woT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wo, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, riT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, ri, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, rcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, rc, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, rfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, rf, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, roT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, ro, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + } + } +} +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16_amx.tpl.c" + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* Store result weight matrices in KCCK bf16 format and downcovert to bf16 */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + /* Below is the commented reference code */ +#if 0 + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; jc++) { + for (jk = 0; jk < bk; jk++) { + libxsmm_bfloat16_hp tmp; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwi_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwc_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwf_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk); + LIBXSMM_VLA_ACCESS(5, dwo_bf16, ikb, icb, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = tmp.i[1]; + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; jc++) { + for (jk = 0; jk < bk; jk++) { + libxsmm_bfloat16_hp tmp; + tmp.f = LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, dri_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, drc_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, drf_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + tmp.f = LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk); + LIBXSMM_VLA_ACCESS(5, dro_bf16, ikb, icb, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = tmp.i[1]; + } + } + } +#endif + __m512i c01; + const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bc; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwi_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwc_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwf_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwo_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dri_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, drc_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, drf_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dro_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } +#else + /* TODO: Add here non AVX512 replacement code */ + LIBXSMM_UNUSED(thr_begin_kk); + LIBXSMM_UNUSED(thr_begin_ck); + LIBXSMM_UNUSED(ikic); + LIBXSMM_UNUSED(jk); + LIBXSMM_UNUSED(jc); + LIBXSMM_UNUSED(thr_end_ck); + LIBXSMM_UNUSED(thr_end_kk); +#endif + libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + reformat_cycles += _end - _start; + } +#endif +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..f7b04c0f81ca3f5b5054d65b9fcffd0f13aa029d --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c @@ -0,0 +1,526 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +for (j = t-1; j >= 0; --j) { + /* let's run the cell in blocks for good locality */ +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) + /* Compute dcp, dci, di, df, dp */ + cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + if (bcbk_multiples_of_16) { + if (K % 2048 != 0 || LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) { +#include "libxsmm_internal_lstm_bwdupd_fused_eltwise.tpl.c" + } else { + /* Also reformat di, dci, df and dp to be used in the UPD pass in blocked format ... */ +#include "libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat.tpl.c" + } + } else { + /* compute dhp */ + if (j == t-1) { + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } + /* compute dcp */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (j == t-1) { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcs, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K) ); + } + /* compute dci */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K) ); + /* compute di */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); + /* compute df */ + if (j == 0) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + } + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + /* compute dp */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K) ); + /* update dcp */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K) ); + } +#else + /* compute dhp */ + if (j == t-1) { + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) ); + } + /* compute dcp */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + if (j == t-1) { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcs, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K) ); + } else { + libxsmm_internal_matrix_add_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K) ); + } + /* compute dci */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_square_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K) ); + /* compute di */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K), &LIBXSMM_VLA_ACCESS(2, di, in, ik, K) ); + /* compute df */ + if (j == 0) { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + } else { + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + } + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K), &LIBXSMM_VLA_ACCESS(2, df, in, ik, K) ); + /* compute dp */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K) ); + libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, t1, in, ik, K), &LIBXSMM_VLA_ACCESS(2, t2, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K) ); + /* update dcp */ + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K), &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K) ); +#endif + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + eltwise_cycles += _end - _start; + } +#endif + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + in = (icin / (C/bc))*bn; + ic = (icin % (C/bc))*bc; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, j, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (j == 0) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, j-1, en, ek, N, K); + } + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + act_trans_cycles += _end - _start; + } +#endif + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* dx = W^T * difoc */ + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic % (N/bn))*bn; + icb = inic / (N/bn); + ic = icb*bc; + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wiT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C) , &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wcT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, dci, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C) , &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wfT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C) , &blocks); + + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik += bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, woT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(2, dp, in, ik + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C) , &blocks); + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dx_cycles += _end - _start; + } +#endif + } + +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + + dout_ptr = (j > 0) ? (element_output_type*) &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) : (element_output_type*) &LIBXSMM_VLA_ACCESS(2, dhp, in, ik, K); + + if (KB == 0) libxsmm_internal_matrix_zero_ld( bk, bn, K, dout_ptr); + /* dout += R^T * difoc */ + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, riT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, di, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rcT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, dci, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rfT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, df, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, roT, ikb, icb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, dp, in, ic + KB*KB_BLOCKS*bk, K); + } + /* Reduce batch gemm call */ + batchreduce_kerneld(A_array, B_array, dout_ptr, &blocks); + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dout_cycles += _end - _start; + } +#endif + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + if ((C == K) && (bc == bk) && (bcbk_multiples_of_16 == 1)) { + if (K % 2048 != 0) { + /* Interleave computation of dr = difoc * h^T and dw = difoc * x^T to take advantage of temporal locality */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + blocks = nBlocks; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } else { + /* Interleave computation of dr = difoc * h^T and dw = difoc * x^T to take advantage of temporal locality */ + /* Use blocked format for di, dci, df and dp */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + blocks = nBlocks; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dciB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dciB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dpB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(4, dpB, inb, ikb, 0, 0, kBlocks, bn, bk); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } + } else { + /* dr = difoc * h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + batchreduce_kernelb1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = difoc * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + batchreduce_kernelc1(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dwdr_cycles += _end - _start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* gradient bias */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bcbk_multiples_of_16) { + for (ik = k_thr_begin; ik < k_thr_end; ik += 16) { + dbi_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbi[ik]); + dbf_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbf[ik]); + dbo_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbo[ik]); + dbc_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbc[ik]); + for (in = 0; in < N; in++) { + dbi_sum = _mm512_add_ps(dbi_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, di, in, ik, K))); + dbf_sum = _mm512_add_ps(dbf_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, df, in, ik, K))); + dbo_sum = _mm512_add_ps(dbo_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, dp, in, ik, K))); + dbc_sum = _mm512_add_ps(dbc_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, dci, in, ik, K))); + } + _mm512_storeu_ps(&dbi[ik], dbi_sum); + _mm512_storeu_ps(&dbf[ik], dbf_sum); + _mm512_storeu_ps(&dbo[ik], dbo_sum); + _mm512_storeu_ps(&dbc[ik], dbc_sum); + } + } else { + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + dbi[ik] += LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + dbf[ik] += LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + dbo[ik] += LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + dbc[ik] += LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + } + } + } +#else + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + dbi[ik] += LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + dbf[ik] += LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + dbo[ik] += LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + dbc[ik] += LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + } + } +#endif +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + gradient_cycles += _end - _start; + } +#endif + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..42789b49c8e5576ffc4783e8307dcd00ff3b302a --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16.tpl.c @@ -0,0 +1,343 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const src = _src; \ + libxsmm_bfloat16 *const dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&src[(__j*ld)+__i])); \ + _mm512_storeu_si512(&dst[(__j*ld)+__i], packed_result); \ + } \ + } \ +} while (0) + +for (j = t-1; j >= 0; --j) { + /* let's run the cell in blocks for good locality */ +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) + /* Compute dcp, dci, di, df, dp */ + cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + if (bcbk_multiples_of_16) { + /* Also reformat di, dci, df and dp to be used in the UPD pass in blocked format ... */ +#include "libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat_bf16.tpl.c" + } else { + /* TODO: Add alternative path here */ + } +#else + /* TODO: Add alternative path here */ +#endif + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + eltwise_cycles += _end - _start; + } +#endif + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + in = (icin / (C/bc))*bn; + ic = (icin % (C/bc))*bc; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, j, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (j == 0) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, j-1, en, ek, N, K); + } + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + act_trans_cycles += _end - _start; + } +#endif + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* dx = W^T * difoc */ + blocks = KB_BLOCKS; + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic % (N/bn))*bn; + icb = inic / (N/bn); + ic = icb*bc; + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wiT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, di, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wcT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, dci, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wfT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, df, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, woT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, dp, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + /* If last block, make sure we downconvert dx to bf16 */ + if (KB == BF-1) { + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bc, bn, C, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &LIBXSMM_VLA_ACCESS(3, dx_bf16, j, in, ic, N, C)); + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dx_cycles += _end - _start; + } +#endif + } + +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + blocks = KB_BLOCKS; + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + dout_ptr = (j > 0) ? (float*) &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) : (float*) &LIBXSMM_VLA_ACCESS(2, dhp_f32, in, ik, K); + + if (KB == 0) libxsmm_internal_matrix_zero_ld( bk, bn, K, dout_ptr); + /* dout += R^T * difoc */ + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, riT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, di, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, rcT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, dci, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, rfT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, df, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, roT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, dp, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + /* Make sure when last and j == 0 to downconvert dhp to BF16 */ + if ((j == 0) && (KB == BF-1)) { + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, dout_ptr, &LIBXSMM_VLA_ACCESS(2, dhp, in, ik, K)); + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dout_cycles += _end - _start; + } +#endif + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + blocks = nBlocks; + if ((C == K) && (bc == bk) && (bcbk_multiples_of_16 == 1)) { + /* Interleave computation of dr = difoc * h^T and dw = difoc * x^T to take advantage of temporal locality */ + /* Use blocked format for di, dci, df and db */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } else { + /* dr = difoc * h^T */ + /* Use blocked format for di, dci, df and db */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = difoc * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dwdr_cycles += _end - _start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* gradient bias */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bcbk_multiples_of_16) { + for (ik = k_thr_begin; ik < k_thr_end; ik += 16) { + dbi_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbi[ik]); + dbf_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbf[ik]); + dbo_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbo[ik]); + dbc_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbc[ik]); + for (in = 0; in < N; in++) { + dbi_sum = _mm512_add_ps(dbi_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, di, in, ik, K)))); + dbf_sum = _mm512_add_ps(dbf_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, df, in, ik, K)))); + dbo_sum = _mm512_add_ps(dbo_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dp, in, ik, K)))); + dbc_sum = _mm512_add_ps(dbc_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(2, dci, in, ik, K)))); + } + _mm512_storeu_ps(&dbi[ik], dbi_sum); + _mm512_storeu_ps(&dbf[ik], dbf_sum); + _mm512_storeu_ps(&dbo[ik], dbo_sum); + _mm512_storeu_ps(&dbc[ik], dbc_sum); + /* Downconvert delta bias to bf16 if done with all timesteps */ + if (j == 0) { + _mm256_storeu_si256((__m256i*)&dbi_bf16[ik], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(dbi_sum)); + _mm256_storeu_si256((__m256i*)&dbf_bf16[ik], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(dbf_sum)); + _mm256_storeu_si256((__m256i*)&dbo_bf16[ik], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(dbo_sum)); + _mm256_storeu_si256((__m256i*)&dbc_bf16[ik], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(dbc_sum)); + } + } + } else { + /* TODO: Add alternative path here */ + } +#else + /* TODO: Add alternative path here */ +#endif +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + gradient_cycles += _end - _start; + } +#endif + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..4972227347d2d44266b34ccb80039b584e9d8090 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core_bf16_amx.tpl.c @@ -0,0 +1,342 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const __src = _src; \ + libxsmm_bfloat16 *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i __packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + __packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i])); \ + _mm512_storeu_si512((libxsmm_bfloat16*)&__dst[(__j*ld)+__i], (__m512i) __packed_result); \ + } \ + } \ +} while (0) + +for (j = t-1; j >= 0; --j) { + /* let's run the cell in blocks for good locality */ +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) + /* Compute dcp, dci, di, df, dp */ + cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + if (bcbk_multiples_of_16) { + /* Also reformat di, dci, df and dp to be used in the UPD pass in blocked format ... */ +#include "libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat_bf16.tpl.c" + } else { + /* TODO: Add alternative path here */ + } +#else + /* TODO: Add alternative path here */ +#endif + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + eltwise_cycles += _end - _start; + } +#endif + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + in = (icin / (C/bc))*bn; + ic = (icin % (C/bc))*bc; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, j, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (j == 0) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + in = (ikin / (K/bk))*bn; + ik = (ikin % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, j-1, en, ek, N, K); + } + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + act_trans_cycles += _end - _start; + } +#endif + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* dx = W^T * difoc */ + blocks = KB_BLOCKS; + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic % (N/bn))*bn; + icb = inic / (N/bn); + ic = icb*bc; + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wiT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, di, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wcT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, dci, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wfT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, df, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, woT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(2, dp, in, KB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &blocks); + + /* If last block, make sure we downconvert dx to bf16 */ + if (KB == BF-1) { + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bc, bn, C, &LIBXSMM_VLA_ACCESS(3, dx, j, in, ic, N, C), &LIBXSMM_VLA_ACCESS(3, dx_bf16, j, in, ic, N, C)); + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dx_cycles += _end - _start; + } +#endif + } + +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + blocks = KB_BLOCKS; + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + dout_ptr = (j > 0) ? (float*) &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K) : (float*) &LIBXSMM_VLA_ACCESS(2, dhp_f32, in, ik, K); + + if (KB == 0) libxsmm_internal_matrix_zero_ld( bk, bn, K, dout_ptr); + /* dout += R^T * difoc */ + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, riT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, di, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, rcT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, dci, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, rfT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, df, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, roT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, dp, in, KB*KB_BLOCKS*bk, K), + dout_ptr, &blocks); + + /* Make sure when last and j == 0 to downconvert dhp to BF16 */ + if ((j == 0) && (KB == BF-1)) { + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, dout_ptr, &LIBXSMM_VLA_ACCESS(2, dhp, in, ik, K)); + } + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dout_cycles += _end - _start; + } +#endif + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + blocks = nBlocks; + if ((C == K) && (bc == bk) && (bcbk_multiples_of_16 == 1)) { + /* Interleave computation of dr = difoc * h^T and dw = difoc * x^T to take advantage of temporal locality */ + /* Use blocked format for di, dci, df and db */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } else { + /* dr = difoc * h^T */ + /* Use blocked format for di, dci, df and db */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = difoc * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, xT, ic, 0, N), + &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + dwdr_cycles += _end - _start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* gradient bias */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bcbk_multiples_of_16) { + for (ik = k_thr_begin; ik < k_thr_end; ik += 16) { + dbi_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbi[ik]); + dbf_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbf[ik]); + dbo_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbo[ik]); + dbc_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbc[ik]); + for (in = 0; in < N; in++) { + dbi_sum = _mm512_add_ps(dbi_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(2, di, in, ik, K))); + dbf_sum = _mm512_add_ps(dbf_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(2, df, in, ik, K))); + dbo_sum = _mm512_add_ps(dbo_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(2, dp, in, ik, K))); + dbc_sum = _mm512_add_ps(dbc_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(2, dci, in, ik, K))); + } + _mm512_store_ps(&dbi[ik], dbi_sum); + _mm512_store_ps(&dbf[ik], dbf_sum); + _mm512_store_ps(&dbo[ik], dbo_sum); + _mm512_store_ps(&dbc[ik], dbc_sum); + /* Downconvert delta bias to bf16 if done with all timesteps */ + if (j == 0) { + _mm512_storecvt_fp32_bf16(&dbi_bf16[ik], dbi_sum); + _mm512_storecvt_fp32_bf16(&dbf_bf16[ik], dbf_sum); + _mm512_storecvt_fp32_bf16(&dbo_bf16[ik], dbo_sum); + _mm512_storecvt_fp32_bf16(&dbc_bf16[ik], dbc_sum); + } + } + } else { + /* TODO: Add alternative path here */ + } +#else + /* TODO: Add alternative path here */ +#endif +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + gradient_cycles += _end - _start; + } +#endif + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..c731115b025eae8c8026b645cd8f8522306d0d92 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c @@ -0,0 +1,366 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, inb, icb, jk, jb, jc, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +const int lpb = handle->lpb; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +const int bn_lp = bn/lpb; +unsigned long long blocks; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *wt = (element_filter_type*)handle->wt->data; +element_filter_type *rt = (element_filter_type*)handle->rt->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_input_type *dcsp = (element_input_type* )handle->dcsp->data; +element_input_type *dhpD = (element_input_type* )handle->dhp->data; +element_filter_type *dw = (element_filter_type*)handle->dw->data; +element_filter_type *dr = (element_filter_type*)handle->dr->data; +element_output_type *db_bf16 = (element_output_type*)handle->db->data; +element_output_type *dcsD = (element_output_type*)handle->dcs->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *diD = (element_output_type*)handle->scratch_di; +element_output_type *dfD = (element_output_type*)handle->scratch_df; +element_output_type *doD = (element_output_type*)handle->scratch_do; +element_output_type *dciD = (element_output_type*)handle->scratch_dci; +float *dxD = (float*)handle->scratch_dx; +float *doutD = (float*)handle->scratch_deltat; +float *dhpD_f32 = (float*)handle->scratch_dhp; +float *db = (float*)handle->scratch_db; +element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT; +#if 0 +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +#endif +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +float *w_scratch = (float*)handle->scratch_w; +float *r_scratch = (float*)handle->scratch_r; +element_filter_type *witD = &(wt[0]); +element_filter_type *wctD = &(wt[C*K]); +element_filter_type *wftD = &(wt[2*C*K]); +element_filter_type *wotD = &(wt[3*C*K]); +element_filter_type *ritD = &(rt[0]); +element_filter_type *rctD = &(rt[K*K]); +element_filter_type *rftD = &(rt[2*K*K]); +element_filter_type *rotD = &(rt[3*K*K]); +element_filter_type *dwiD = &(dw[0]); +element_filter_type *dwcD = &(dw[C*K]); +element_filter_type *dwfD = &(dw[2*C*K]); +element_filter_type *dwoD = &(dw[3*C*K]); +element_filter_type *driD = &(dr[0]); +element_filter_type *drcD = &(dr[K*K]); +element_filter_type *drfD = &(dr[2*K*K]); +element_filter_type *droD = &(dr[3*K*K]); +float *dwiD_scratch = &(w_scratch[0]); +float *dwcD_scratch = &(w_scratch[C*K]); +float *dwfD_scratch = &(w_scratch[2*C*K]); +float *dwoD_scratch = &(w_scratch[3*C*K]); +float *driD_scratch = &(r_scratch[0]); +float *drcD_scratch = &(r_scratch[K*K]); +float *drfD_scratch = &(r_scratch[2*K*K]); +float *droD_scratch = &(r_scratch[3*K*K]); +float *dbi = &(db[0]); +float *dbc = &(db[K]); +float *dbf = &(db[2*K]); +float *dbo = &(db[3*K]); +element_output_type *dbi_bf16 = &(db_bf16[0]); +element_output_type *dbc_bf16 = &(db_bf16[K]); +element_output_type *dbf_bf16 = &(db_bf16[2*K]); +element_output_type *dbo_bf16 = &(db_bf16[3*K]); +#if 0 +element_filter_type *scratch_wiT = &(scratch_wT[0]); +element_filter_type *scratch_wcT = &(scratch_wT[C*K]); +element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]); +element_filter_type *scratch_woT = &(scratch_wT[3*C*K]); +element_filter_type *scratch_riT = &(scratch_rT[0]); +element_filter_type *scratch_rcT = &(scratch_rT[K*K]); +element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]); +element_filter_type *scratch_roT = &(scratch_rT[3*K*K]); +#endif +/*element_output_type *t1D = (element_output_type*)handle->scratch_t1;*/ +/*element_output_type *t2D = (element_output_type*)handle->scratch_t2;*/ +/* multidimensional arrays */ +/*LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);*/ +/*LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);*/ +LIBXSMM_VLA_DECL(5, element_input_type, x, xt, nBlocks, cBlocks, bn, bc); +LIBXSMM_VLA_DECL(4, element_input_type, cp, csp, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_input_type, hp, hpD, kBlocks, bn, bk); +#if 0 +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); +#endif +LIBXSMM_VLA_DECL(5, element_output_type, cs, cst, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, h, ht, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, i, it, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, f, ft, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, o, ot, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, ci, cit, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, co, cot, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, float, dx, dxD, nBlocks, cBlocks, bn, bc); +LIBXSMM_VLA_DECL(5, element_input_type, dx_bf16, dxt, nBlocks, cBlocks, bn, bc); +LIBXSMM_VLA_DECL(4, element_input_type, dcp, dcsp, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_input_type, dhp, dhpD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, float, dhp_f32, dhpD_f32, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, float, dwi, dwiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwf, dwfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwo, dwoD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dwc, dwcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, float, dri, driD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drf, drfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, dro, droD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, float, drc, drcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(5, element_filter_type, dwi_bf16, dwiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwc_bf16, dwcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwf_bf16, dwfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dwo_bf16, dwoD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dri_bf16, driD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, drc_bf16, drcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, drf_bf16, drfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, dro_bf16, droD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(4, element_output_type, dcs, dcsD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, dh, dht, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, di, diD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, df, dfD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dp, doD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_output_type, dci, dciD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, diB, (element_output_type*)handle->scratch_diB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, nBlocks, bn_lp, bk, lpb); +LIBXSMM_VLA_DECL(4, float, dout, doutD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_input_type, xT, scratch_xT, nBlocks, bc, bn); +LIBXSMM_VLA_DECL(5, element_filter_type, wiT, witD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wcT, wctD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wfT, wftD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, woT, wotD, kBlocks, bk_lp, bc, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, riT, ritD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rcT, rctD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rfT, rftD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, roT, rotD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(4, element_output_type, hT, scratch_hT, nBlocks, bk, bn); +float *dout_ptr = NULL; +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->bwdupd_kernela; /*libxsmm_bsmmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->bwdupd_kernelb; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelc = handle->bwdupd_kernelc; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kerneld = handle->bwdupd_kerneld; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);*/ +libxsmm_bsmmfunction_reducebatch_addr tile_config_kernel = handle->bwdupd_tileconfig; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);*/ + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +element_output_type *cps_ptr = NULL; +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K; +__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum; +#endif +#ifdef PROFILE +__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0; +float total_time = 0.0; +#endif +int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0; + +libxsmm_blasint ikic, inic, inik, icin, ikin; +__m512i c01; +const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + +/* Hoist tileconfig if possible */ +if ((bk % 32 == 0) && (bc % 32 == 0) && (bn % 32 == 0)) { + tile_config_kernel(NULL, NULL, NULL, NULL); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (K > 1024 && K <= 2048) { + BF = 8; + while (kBlocks % BF != 0) { + BF--; + } +} + +if (K > 2048) { + BF = 16; + while (kBlocks % BF != 0) { + BF--; + } +} + +BF = handle->bwdupd_block; +KB_BLOCKS = kBlocks/BF; + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxD, start_thread, tid, handle->desc.threads); +} + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads); +} + +/* Here we assume that the weight tensors come in transposed from framework */ +#if 0 +#ifdef PROFILE +if (ltid == 0) _start = _rdtsc(); +#endif +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(5, wiT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wi, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, wcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wc, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, wfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wf, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, woT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bc, lpb) = LIBXSMM_VLA_ACCESS(5, wo, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, riT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, ri, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, rcT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, rc, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, rfT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, rf, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + LIBXSMM_VLA_ACCESS(5, roT, ic, ik, jk/lpb, jc, jk%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(5, ro, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb); + } + } +} +#ifdef PROFILE +if (ltid == 0) { + _end = _rdtsc(); + weight_trans_cycles += _end - _start; +} +#endif +#endif + +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_core_bf16_amx.tpl.c" + +handle->tilerelease_kernel(NULL, NULL, NULL); + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#ifdef PROFILE + if (ltid == 0) _start = _rdtsc(); +#endif + /* Store result weight matrices in KCCK bf16 format and downcovert to bf16 */ +#if defined(LIBXSMM_RNN_CELL_AVX512) +#else + /* TODO: Add here non AVX512 replacement code */ + LIBXSMM_UNUSED(thr_begin_kk); + LIBXSMM_UNUSED(thr_begin_ck); + LIBXSMM_UNUSED(ikic); + LIBXSMM_UNUSED(jk); + LIBXSMM_UNUSED(jc); + LIBXSMM_UNUSED(thr_end_ck); + LIBXSMM_UNUSED(thr_end_kk); +#endif + libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE + if (ltid == 0) { + _end = _rdtsc(); + reformat_cycles += _end - _start; + } +#endif +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9))); + printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9))); + printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9))); + printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); +} +#undef PROFILE +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_core_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_core_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..94b535e18c948b6fc262b0486640d5d9e9ad446b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_core_bf16_amx.tpl.c @@ -0,0 +1,405 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const __src = _src; \ + libxsmm_bfloat16 *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i __packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + __packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i])); \ + _mm512_storeu_si512((libxsmm_bfloat16*)&__dst[(__j*ld)+__i], (__m512i) __packed_result); \ + } \ + } \ +} while (0) + +for (j = t-1; j >= 0; --j) { + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + in = (inik % (N/bn))*bn; + ik = (inik / (N/bn))*bk; + /* Compute dcp, dci, di, df, dp */ + cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(4, cp, inb, ikb, 0, 0, kBlocks, bn, bk) : &LIBXSMM_VLA_ACCESS(5, cs, j-1, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + /* Also reformat di, dci, df and dp to be used in the UPD pass in blocked format ... */ +#include "libxsmm_internal_lstm_bwdupd_fused_eltwise_ncnc_reformat_bf16.tpl.c" + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + inb = icin / (C/bc); + icb = icin % (C/bc); + if (bc == 32 && bk == 32) { + trans_act((short int*)&LIBXSMM_VLA_ACCESS(5, x, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc), (short int*)&LIBXSMM_VLA_ACCESS(4, xT, icb, inb, 0, 0, nBlocks, bc, bn)); + } else { + in = inb*bn; + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + LIBXSMM_VLA_ACCESS(4, xT, icb, inb, jc, jb, nBlocks, bc, bn) = LIBXSMM_VLA_ACCESS(5, x, j, inb, icb, jb, jc, nBlocks, cBlocks, bn, bc); + } + } + } + } + + /* transpose ht for current timestep */ + if (j == 0) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + inb = ikin / (K/bk); + ikb = ikin % (K/bk); + if (bc == 32 && bk == 32) { + trans_act((short int*)&LIBXSMM_VLA_ACCESS(4, hp, inb, ikb, 0, 0, kBlocks, bn, bk), (short int*)&LIBXSMM_VLA_ACCESS(4, hT, ikb, inb, 0, 0, nBlocks, bk, bn)); + } else { + in = inb*bn; + ik = ikb*bk; + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + LIBXSMM_VLA_ACCESS(4, hT, ikb, inb, jk, jb, nBlocks, bk, bn) = LIBXSMM_VLA_ACCESS(4, hp, inb, ikb, jb, jk, kBlocks, bn, bk); + } + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + inb = ikin / (K/bk); + ikb = ikin % (K/bk); + if (bc == 32 && bk == 32) { + trans_act((short int*)&LIBXSMM_VLA_ACCESS(5, h, j-1, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), (short int*)&LIBXSMM_VLA_ACCESS(4, hT, ikb, inb, 0, 0, nBlocks, bk, bn)); + } else { + ik = ikb*bk; + in = inb*bn; + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + LIBXSMM_VLA_ACCESS(4, hT, ikb, inb, jk, jb, nBlocks, bk, bn) = LIBXSMM_VLA_ACCESS(5, h, j-1, inb, ikb, jb, jk, nBlocks, kBlocks, bn, bk); + } + } + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dx = W^T * difoc */ + blocks = KB_BLOCKS; + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + inb = inic % (N/bn); + in = inb*bn; + icb = inic / (N/bn); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wiT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, di, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, dx, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wcT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, dci, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, dx, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wfT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, df, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, dx, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc), &blocks); + + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, woT, icb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bc, lpb), + &LIBXSMM_VLA_ACCESS(4, dp, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, dx, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc), &blocks); + + /* If last block, make sure we downconvert dx to bf16 */ + if (KB == BF-1) { + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bc, bn, bc, &LIBXSMM_VLA_ACCESS(5, dx, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc), &LIBXSMM_VLA_ACCESS(5, dx_bf16, j, inb, icb, 0, 0, nBlocks, cBlocks, bn, bc)); + } + } + } + } + + blocks = KB_BLOCKS; + for (KB = 0; KB < BF; KB++) { + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + inb = inik % (N/bn); + in = inb*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + dout_ptr = (j > 0) ? (float*) &LIBXSMM_VLA_ACCESS(4, dout, inb, ikb, 0, 0, kBlocks, bn, bk) : (float*) &LIBXSMM_VLA_ACCESS(4, dhp_f32, inb, ikb, 0, 0, kBlocks, bn, bk); + + if (KB == 0) libxsmm_internal_matrix_zero_ld( bk, bn, bk, dout_ptr); + /* dout += R^T * difoc */ + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, riT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, di, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, rcT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, dci, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, rfT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, df, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + dout_ptr, &blocks); + + batchreduce_kerneld(&LIBXSMM_VLA_ACCESS(5, roT, ikb, KB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, dp, inb, KB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + dout_ptr, &blocks); + + /* Make sure when last and j == 0 to downconvert dhp to BF16 */ + if ((j == 0) && (KB == BF-1)) { + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, bk, dout_ptr, &LIBXSMM_VLA_ACCESS(4, dhp, inb, ikb, 0, 0, kBlocks, bn, bk)); + } + } + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + blocks = nBlocks; + if ((C == K) && (bc == bk) && (bcbk_multiples_of_16 == 1)) { + /* Interleave computation of dr = difoc * h^T and dw = difoc * x^T to take advantage of temporal locality */ + /* Use blocked format for di, dci, df and db */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dri_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwi_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, drc_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwc_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, drf_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwf_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dro_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwo_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + } else { + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dri_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, drc_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, drf_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hT, icb, 0, 0, 0, nBlocks, bk, bn), + &LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc+1, jk, kBlocks, bk, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dro_bf16, ikb, icb, jc/lpb, jk, 0, kBlocks, bk_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ikb = ikic % (K/bk); + ik = ikb*bk; + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, diB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bc; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwi_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dciB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bc; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwc_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dfB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bc; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwf_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + + batchreduce_kernelc(&LIBXSMM_VLA_ACCESS(5, dpB, ikb, 0, 0, 0, 0, nBlocks, bn_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, xT, icb, 0, 0, 0, nBlocks, bc, bn), + &LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + if (j == 0) { + for (jc = 0; jc < bk; jc+=2) { + for (jk = 0; jk < bk; jk+=16) { + c01 = (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc+1, jk, cBlocks, bc, bk)), LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk))); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(5, dwo_bf16, ikb, icb, jc/lpb, jk, 0, cBlocks, bc_lp, bk, lpb), _mm512_permutexvar_epi16(perm_index, c01)); + } + } + } + } + } + + /* gradient bias */ + if (bcbk_multiples_of_16) { + for (ik = k_thr_begin; ik < k_thr_end; ik += 16) { + dbi_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbi[ik]); + dbf_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbf[ik]); + dbo_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbo[ik]); + dbc_sum = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&dbc[ik]); + for (in = 0; in < N; in++) { + dbi_sum = _mm512_add_ps(dbi_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(4, di, in/bn, ik/bk, in%bn, ik%bk, kBlocks, bn, bk))); + dbf_sum = _mm512_add_ps(dbf_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(4, df, in/bn, ik/bk, in%bn, ik%bk, kBlocks, bn, bk))); + dbo_sum = _mm512_add_ps(dbo_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(4, dp, in/bn, ik/bk, in%bn, ik%bk, kBlocks, bn, bk))); + dbc_sum = _mm512_add_ps(dbc_sum, _mm512_loadcvt_bf16_fp32(&LIBXSMM_VLA_ACCESS(4, dci, in/bn, ik/bk, in%bn, ik%bk, kBlocks, bn, bk))); + } + _mm512_store_ps(&dbi[ik], dbi_sum); + _mm512_store_ps(&dbf[ik], dbf_sum); + _mm512_store_ps(&dbo[ik], dbo_sum); + _mm512_store_ps(&dbc[ik], dbc_sum); + /* Downconvert delta bias to bf16 if done with all timesteps */ + if (j == 0) { + _mm512_storecvt_fp32_bf16(&dbi_bf16[ik], dbi_sum); + _mm512_storecvt_fp32_bf16(&dbf_bf16[ik], dbf_sum); + _mm512_storecvt_fp32_bf16(&dbo_bf16[ik], dbo_sum); + _mm512_storecvt_fp32_bf16(&dbc_bf16[ik], dbc_sum); + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..50ad74c6ed6148190d6743d03a1ee9e2221c96ef --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c @@ -0,0 +1,214 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS, ikic, jk, jc; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w; +element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *woD = &(w[3*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *roD = &(r[3*K]); +element_filter_type *wiD_scratch = &(w_scratch[0]); +element_filter_type *wcD_scratch = &(w_scratch[C*K]); +element_filter_type *wfD_scratch = &(w_scratch[2*C*K]); +element_filter_type *woD_scratch = &(w_scratch[3*C*K]); +element_filter_type *riD_scratch = &(r_scratch[0]); +element_filter_type *rcD_scratch = &(r_scratch[K*K]); +element_filter_type *rfD_scratch = &(r_scratch[2*K*K]); +element_filter_type *roD_scratch = &(r_scratch[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wo, woD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD_scratch, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ro, roD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD_scratch, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(2, element_filter_type, wi_ck, wiD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wf_ck, wfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wo_ck, woD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wc_ck, wcD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, ri_ck, riD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, rf_ck, rfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, ro_ck, roD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, rc_ck, rcD, 4*K); +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); +/* Auxiliary arrays for batch-reduce gemms */ +const element_filter_type *A_array[1024]; +const element_input_type *B_array[1024]; +element_output_type *cps_ptr = NULL; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; + +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; +float total_time = 0.0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +/* Upfront reformatting of W and R */ +/* reformat W */ +#ifdef PROFILE +if (ltid == 0) reformat_start = _rdtsc(); +#endif +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wi_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wc_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wf_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, wo, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wo_ck, ic*bc+jc, ik*bk+jk, 4*K); + } + } +} + +/* reformat R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ri_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rc_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rf_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(4, ro, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ro_ck, ic*bk+jc, ik*bk+jk, 4*K); + } + } +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE +if (ltid == 0) { + reformat_end = _rdtsc(); + reformat_cycles = reformat_end - reformat_start; +} +#endif + +if (use_fused_implementation) { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c" +} else { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c" +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..ab013d3b504dc0dac755ce7bfaffcbd22f8bd075 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c @@ -0,0 +1,283 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \ +do { \ + libxsmm_bfloat16 *src = _src; \ + float *dst = _dst; \ + libxsmm_blasint __i,__j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_storeu_ps((float*)&dst[(__j*ld)+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&src[(__j*ld)+__i]))); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \ +do { \ + libxsmm_bfloat16 *colv = _colv; \ + float *srcdst = _srcdst; \ + libxsmm_blasint __i,__j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_storeu_ps((float*)&srcdst[(__j*ld)+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i]))); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \ +do { \ + libxsmm_bfloat16 *colv = _colv; \ + float *srcdst = _srcdst; \ + libxsmm_blasint __i,__j; \ + __m512 vbias = _mm512_set1_ps(const_bias); \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_storeu_ps((float*)&srcdst[(__j*ld)+__i], _mm512_add_ps(vbias, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i])))); \ + } \ + } \ +} while (0) + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, ic, /*icb,*/ inik, BF, CB, CB_BLOCKS, KB_BLOCKS, ikic, jk, jc; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const int lpb = handle->lpb; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +unsigned long long blocks, blocksa, blocksb; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_output_type *b = (element_output_type*)handle->b->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w; +element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r; +/* These buffers are scratch for fp32 output of gemms (intermmediate results) */ +float *cst = (float*)handle->cst_scratch; +float *ht = (float*)handle->ht_scratch; +float *it = (float*)handle->it_scratch; +float *ft = (float*)handle->ft_scratch; +float *ot = (float*)handle->ot_scratch; +float *cit = (float*)handle->cit_scratch; +float *cot = (float*)handle->cot_scratch; +/* This has to be also upconverted since it is used in the elementwise functions */ +float *csp_f32 = (float*)handle->csp_scratch; +/* These are the output bf16 data */ +element_output_type *cst_bf16 = (element_output_type*)handle->cst->data; +element_output_type *ht_bf16 = (element_output_type*)handle->ht->data; +element_output_type *it_bf16 = (element_output_type*)handle->it->data; +element_output_type *ft_bf16 = (element_output_type*)handle->ft->data; +element_output_type *ot_bf16 = (element_output_type*)handle->ot->data; +element_output_type *cit_bf16 = (element_output_type*)handle->cit->data; +element_output_type *cot_bf16 = (element_output_type*)handle->cot->data; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *woD = &(w[3*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *roD = &(r[3*K]); +element_filter_type *wiD_scratch = &(w_scratch[0]); +element_filter_type *wcD_scratch = &(w_scratch[C*K]); +element_filter_type *wfD_scratch = &(w_scratch[2*C*K]); +element_filter_type *woD_scratch = &(w_scratch[3*C*K]); +element_filter_type *riD_scratch = &(r_scratch[0]); +element_filter_type *rcD_scratch = &(r_scratch[K*K]); +element_filter_type *rfD_scratch = &(r_scratch[2*K*K]); +element_filter_type *roD_scratch = &(r_scratch[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(2, float, cp, csp_f32, K); +LIBXSMM_VLA_DECL(2, element_input_type, cp_bf16, csp, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_filter_type, wi_ck, wiD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wf_ck, wfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wo_ck, woD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wc_ck, wcD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, ri_ck, riD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, rf_ck, rfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, ro_ck, roD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, rc_ck, rcD, 4*K); +LIBXSMM_VLA_DECL(3, float, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, float, h, ht, N, K); +LIBXSMM_VLA_DECL(3, float, i, it, N, K); +LIBXSMM_VLA_DECL(3, float, f, ft, N, K); +LIBXSMM_VLA_DECL(3, float, o, ot, N, K); +LIBXSMM_VLA_DECL(3, float, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, float, co, cot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, cs_out, cst_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h_out, ht_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i_out, it_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f_out, ft_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o_out, ot_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci_out, cit_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co_out, cot_bf16, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb; + +float *cps_ptr = NULL; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; + +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; +float total_time = 0.0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +/* Upfront reformatting of W and R */ +/* reformat W */ +#ifdef PROFILE +if (ltid == 0) reformat_start = _rdtsc(); +#endif +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc;++jc) { + LIBXSMM_VLA_ACCESS(5, wi, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wi_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wc, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wc_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wf, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wf_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wo, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wo_ck, ic*bc+jc, ik*bk+jk, 4*K); + } + } +} + +/* reformat R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, ri, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ri_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rc, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rc_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rf, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rf_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, ro, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ro_ck, ic*bk+jc, ik*bk+jk, 4*K); + } + } +} + +/* Upconvert the cp input to fp32 that is used for elementwise stuff */ +for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + libxsmm_internal_matrix_cvt_bf16_fp32_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K)); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE +if (ltid == 0) { + reformat_end = _rdtsc(); + reformat_cycles = reformat_end - reformat_start; +} +#endif + +if (use_fused_implementation) { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c" +} else { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c" +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif + +#undef MATRIX_CVT_BF16_FP32_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..b9cc50c31b7d6dcd8b50ad2b2b1fd28d5820418c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c @@ -0,0 +1,291 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \ +do { \ + libxsmm_bfloat16 *__src = _src; \ + float *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__dst[(__j*ld)+__i], _mm512_loadcvt_bf16_fp32(&__src[(__j*ld)+__i])); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \ +do { \ + libxsmm_bfloat16 *__colv = _colv; \ + float *__srcdst = _srcdst; \ + libxsmm_blasint __i, __j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__srcdst[(__j*ld)+__i], _mm512_loadcvt_bf16_fp32(&__colv[__i])); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \ +do { \ + libxsmm_bfloat16 *__colv = _colv; \ + float *__srcdst = _srcdst; \ + libxsmm_blasint __i, __j; \ + __m512 __vbias = _mm512_set1_ps(const_bias); \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__srcdst[(__j*ld)+__i], _mm512_add_ps(__vbias, _mm512_loadcvt_bf16_fp32(&__colv[__i]))); \ + } \ + } \ +} while (0) + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, ic, inik, BF, CB, CB_BLOCKS, KB_BLOCKS, ikic, jk, jc; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const int lpb = handle->lpb; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +unsigned long long blocks, blocksa, blocksb; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w; +element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r; +element_output_type *b = (element_output_type*)handle->b->data; +/* These buffers are scratch for fp32 output of gemms (intermmediate results) */ +float *cst = (float*)handle->cst_scratch; +float *ht = (float*)handle->ht_scratch; +float *it = (float*)handle->it_scratch; +float *ft = (float*)handle->ft_scratch; +float *ot = (float*)handle->ot_scratch; +float *cit = (float*)handle->cit_scratch; +float *cot = (float*)handle->cot_scratch; +/* This has to be also upconverted since it is used in the elementwise functions */ +float *csp_f32 = (float*)handle->csp_scratch; +/* These are the output bf16 data */ +element_output_type *cst_bf16 = (element_output_type*)handle->cst->data; +element_output_type *ht_bf16 = (element_output_type*)handle->ht->data; +element_output_type *it_bf16 = (element_output_type*)handle->it->data; +element_output_type *ft_bf16 = (element_output_type*)handle->ft->data; +element_output_type *ot_bf16 = (element_output_type*)handle->ot->data; +element_output_type *cit_bf16 = (element_output_type*)handle->cit->data; +element_output_type *cot_bf16 = (element_output_type*)handle->cot->data; + +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[K]); +element_filter_type *wfD = &(w[2*K]); +element_filter_type *woD = &(w[3*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K]); +element_filter_type *rfD = &(r[2*K]); +element_filter_type *roD = &(r[3*K]); +element_filter_type *wiD_scratch = &(w_scratch[0]); +element_filter_type *wcD_scratch = &(w_scratch[C*K]); +element_filter_type *wfD_scratch = &(w_scratch[2*C*K]); +element_filter_type *woD_scratch = &(w_scratch[3*C*K]); +element_filter_type *riD_scratch = &(r_scratch[0]); +element_filter_type *rcD_scratch = &(r_scratch[K*K]); +element_filter_type *rfD_scratch = &(r_scratch[2*K*K]); +element_filter_type *roD_scratch = &(r_scratch[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(2, float, cp, csp_f32, K); +LIBXSMM_VLA_DECL(2, element_input_type, cp_bf16, csp, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD_scratch, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD_scratch, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(2, element_filter_type, wi_ck, wiD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wf_ck, wfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wo_ck, woD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, wc_ck, wcD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, ri_ck, riD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, rf_ck, rfD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, ro_ck, roD, 4*K); +LIBXSMM_VLA_DECL(2, element_filter_type, rc_ck, rcD, 4*K); +LIBXSMM_VLA_DECL(3, float, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, float, h, ht, N, K); +LIBXSMM_VLA_DECL(3, float, i, it, N, K); +LIBXSMM_VLA_DECL(3, float, f, ft, N, K); +LIBXSMM_VLA_DECL(3, float, o, ot, N, K); +LIBXSMM_VLA_DECL(3, float, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, float, co, cot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, cs_out, cst_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h_out, ht_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i_out, it_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f_out, ft_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o_out, ot_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci_out, cit_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co_out, cot_bf16, N, K); + +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela; /*= libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL );*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb; /* libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL );*/ +const libxsmm_bsmmfunction_reducebatch_addr tile_config_kernel = handle->fwd_tileconfig; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL );*/ + +float *cps_ptr = NULL; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; + +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; +float total_time = 0.0; +#endif + +/* Hoist tileconfig if possible */ +if ((bk % 32 == 0) && (bc % 32 == 0) && (bn % 32 == 0)) { + tile_config_kernel(NULL, NULL, NULL, NULL); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +/* Upfront reformatting of W and R */ +/* reformat W */ +#ifdef PROFILE +if (ltid == 0) reformat_start = _rdtsc(); +#endif +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk)); + ik = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc;++jc) { + LIBXSMM_VLA_ACCESS(5, wi, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wi_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wc, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wc_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wf, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wf_ck, ic*bc+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, wo, ik, ic, jc/lpb, jk, jc%lpb, cBlocks, bc_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, wo_ck, ic*bc+jc, ik*bk+jk, 4*K); + } + } +} + +/* reformat R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(5, ri, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ri_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rc, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rc_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, rf, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, rf_ck, ic*bk+jc, ik*bk+jk, 4*K); + LIBXSMM_VLA_ACCESS(5, ro, ik, ic, jc/lpb, jk, jc%lpb, kBlocks, bk_lp, bk, lpb) = LIBXSMM_VLA_ACCESS(2, ro_ck, ic*bk+jc, ik*bk+jk, 4*K); + } + } +} + +/* Upconvert the cp input to fp32 that is used for elementwise stuff */ +for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + MATRIX_CVT_BF16_FP32_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K)); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE +if (ltid == 0) { + reformat_end = _rdtsc(); + reformat_cycles = reformat_end - reformat_start; +} +#endif + +if (use_fused_implementation) { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16_amx.tpl.c" +} else { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16_amx.tpl.c" +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif + +#undef MATRIX_CVT_BF16_FP32_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..a581f284e579eecbfec1cfa2d507a20b62863acf --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c @@ -0,0 +1,138 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *cst = (element_output_type*)handle->cst->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *it = (element_output_type*)handle->it->data; +element_output_type *ft = (element_output_type*)handle->ft->data; +element_output_type *ot = (element_output_type*)handle->ot->data; +element_output_type *cit = (element_output_type*)handle->cit->data; +element_output_type *cot = (element_output_type*)handle->cot->data; +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[C*K]); +element_filter_type *wfD = &(w[2*C*K]); +element_filter_type *woD = &(w[3*C*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K*K]); +element_filter_type *rfD = &(r[2*K*K]); +element_filter_type *roD = &(r[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wo, woD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, ro, roD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); +/* Auxiliary arrays for batch-reduce gemms */ +const element_filter_type *A_array[1024]; +const element_input_type *B_array[1024]; +element_output_type *cps_ptr = NULL; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0; +float total_time = 0.0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +if (use_fused_implementation) { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c" +} else { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c" +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..39526f8fb8bc754098f41a327019de99ea8b8c9f --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c @@ -0,0 +1,223 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \ +do { \ + libxsmm_bfloat16 *src = _src; \ + float *dst = _dst; \ + libxsmm_blasint __i,__j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_storeu_ps((float*)&dst[(__j*ld)+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&src[(__j*ld)+__i]))); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \ +do { \ + libxsmm_bfloat16 *colv = _colv; \ + float *srcdst = _srcdst; \ + libxsmm_blasint __i,__j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_storeu_ps((float*)&srcdst[(__j*ld)+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i]))); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \ +do { \ + libxsmm_bfloat16 *colv = _colv; \ + float *srcdst = _srcdst; \ + libxsmm_blasint __i,__j; \ + __m512 vbias = _mm512_set1_ps(const_bias); \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_storeu_ps((float*)&srcdst[(__j*ld)+__i], _mm512_add_ps(vbias, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i])))); \ + } \ + } \ +} while (0) + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, /*ic, icb,*/ inik, BF, CB, CB_BLOCKS, KB_BLOCKS; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +int lpb = 2; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +unsigned long long blocks, blocksa, blocksb; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; + +/* These buffers are scratch for fp32 output of gemms (intermmediate results) */ +float *cst = (float*)handle->cst_scratch; +float *ht = (float*)handle->ht_scratch; +float *it = (float*)handle->it_scratch; +float *ft = (float*)handle->ft_scratch; +float *ot = (float*)handle->ot_scratch; +float *cit = (float*)handle->cit_scratch; +float *cot = (float*)handle->cot_scratch; +/* This has to be also upconverted since it is used in the elementwise functions */ +float *csp_f32 = (float*)handle->csp_scratch; +/* These are the output bf16 data */ +element_output_type *cst_bf16 = (element_output_type*)handle->cst->data; +element_output_type *ht_bf16 = (element_output_type*)handle->ht->data; +element_output_type *it_bf16 = (element_output_type*)handle->it->data; +element_output_type *ft_bf16 = (element_output_type*)handle->ft->data; +element_output_type *ot_bf16 = (element_output_type*)handle->ot->data; +element_output_type *cit_bf16 = (element_output_type*)handle->cit->data; +element_output_type *cot_bf16 = (element_output_type*)handle->cot->data; + +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[C*K]); +element_filter_type *wfD = &(w[2*C*K]); +element_filter_type *woD = &(w[3*C*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K*K]); +element_filter_type *rfD = &(r[2*K*K]); +element_filter_type *roD = &(r[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(2, float, cp, csp_f32, K); +LIBXSMM_VLA_DECL(2, element_input_type, cp_bf16, csp, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(3, float, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, float, h, ht, N, K); +LIBXSMM_VLA_DECL(3, float, i, it, N, K); +LIBXSMM_VLA_DECL(3, float, f, ft, N, K); +LIBXSMM_VLA_DECL(3, float, o, ot, N, K); +LIBXSMM_VLA_DECL(3, float, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, float, co, cot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, cs_out, cst_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h_out, ht_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i_out, it_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f_out, ft_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o_out, ot_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci_out, cit_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co_out, cot_bf16, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela; +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb; + +float *cps_ptr = NULL; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; + +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; +float total_time = 0.0; +#endif + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +#ifdef PROFILE +if (ltid == 0) reformat_start = _rdtsc(); +#endif + +/* Upconvert the cp input to fp32 that is used for elementwise stuff */ +for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + MATRIX_CVT_BF16_FP32_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K)); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE +if (ltid == 0) { + reformat_end = _rdtsc(); + reformat_cycles = reformat_end - reformat_start; +} +#endif + +if (use_fused_implementation) { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c" +} else { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c" +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif + +#undef MATRIX_CVT_BF16_FP32_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..4948cd24d2f2cb9d24d1a60ef2f7d1b8283c2e79 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c @@ -0,0 +1,236 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \ +do { \ + libxsmm_bfloat16 *__src = _src; \ + float *const __dst = _dst; \ + libxsmm_blasint __i, __j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__dst[(__j*ld)+__i], _mm512_loadcvt_bf16_fp32(&__src[(__j*ld)+__i])); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \ +do { \ + libxsmm_bfloat16 *__colv = _colv; \ + float *__srcdst = _srcdst; \ + libxsmm_blasint __i, __j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__srcdst[(__j*ld)+__i], _mm512_loadcvt_bf16_fp32(&__colv[__i])); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \ +do { \ + libxsmm_bfloat16 *__colv = _colv; \ + float *__srcdst = _srcdst; \ + libxsmm_blasint __i, __j; \ + __m512 __vbias = _mm512_set1_ps(const_bias); \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__srcdst[(__j*ld)+__i], _mm512_add_ps(__vbias, _mm512_loadcvt_bf16_fp32(&__colv[__i]))); \ + } \ + } \ +} while (0) + +/* helper variables */ +libxsmm_blasint j, ik, ikb, in, /*ic, icb,*/ inik, BF, CB, CB_BLOCKS, KB_BLOCKS; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const int lpb = 2; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +unsigned long long blocks, blocksa, blocksb; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; + +/* These buffers are scratch for fp32 output of gemms (intermmediate results) */ +float *cst = (float*)handle->cst_scratch; +float *ht = (float*)handle->ht_scratch; +float *it = (float*)handle->it_scratch; +float *ft = (float*)handle->ft_scratch; +float *ot = (float*)handle->ot_scratch; +float *cit = (float*)handle->cit_scratch; +float *cot = (float*)handle->cot_scratch; +/* This has to be also upconverted since it is used in the elementwise functions */ +float *csp_f32 = (float*)handle->csp_scratch; +/* These are the output bf16 data */ +element_output_type *cst_bf16 = (element_output_type*)handle->cst->data; +element_output_type *ht_bf16 = (element_output_type*)handle->ht->data; +element_output_type *it_bf16 = (element_output_type*)handle->it->data; +element_output_type *ft_bf16 = (element_output_type*)handle->ft->data; +element_output_type *ot_bf16 = (element_output_type*)handle->ot->data; +element_output_type *cit_bf16 = (element_output_type*)handle->cit->data; +element_output_type *cot_bf16 = (element_output_type*)handle->cot->data; + +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[C*K]); +element_filter_type *wfD = &(w[2*C*K]); +element_filter_type *woD = &(w[3*C*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K*K]); +element_filter_type *rfD = &(r[2*K*K]); +element_filter_type *roD = &(r[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(2, float, cp, csp_f32, K); +LIBXSMM_VLA_DECL(2, element_input_type, cp_bf16, csp, K); +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(3, float, cs, cst, N, K); +LIBXSMM_VLA_DECL(3, float, h, ht, N, K); +LIBXSMM_VLA_DECL(3, float, i, it, N, K); +LIBXSMM_VLA_DECL(3, float, f, ft, N, K); +LIBXSMM_VLA_DECL(3, float, o, ot, N, K); +LIBXSMM_VLA_DECL(3, float, ci, cit, N, K); +LIBXSMM_VLA_DECL(3, float, co, cot, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, cs_out, cst_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, h_out, ht_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, i_out, it_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, f_out, ft_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, o_out, ot_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, ci_out, cit_bf16, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, co_out, cot_bf16, N, K); +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL );*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL );*/ +const libxsmm_bsmmfunction_reducebatch_addr tile_config_kernel = handle->fwd_tileconfig; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL );*/ + +/* Auxiliary arrays for batch-reduce gemms */ +#if 0 +const element_filter_type *A_array[1024]; +const element_input_type *B_array[1024]; +#endif +float *cps_ptr = NULL; + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; +const int use_fused_implementation = handle->use_fwd_fused_impl; /*(C == 2048 && K == 2048) ? 1 : 0;*/ + +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; +float total_time = 0.0; +#endif + +/* Hoist tileconfig if possible */ +if ((bk % 32 == 0) && (bc % 32 == 0) && (bn % 32 == 0)) { + tile_config_kernel(NULL, NULL, NULL, NULL); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +/* Overwrite the blocking factor based on the value passed onto the descriptor */ +BF = handle->fwd_block; + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +#ifdef PROFILE +if (ltid == 0) reformat_start = _rdtsc(); +#endif + +/* Upconvert the cp input to fp32 that is used for elementwise stuff */ +for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + MATRIX_CVT_BF16_FP32_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K)); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE +if (ltid == 0) { + reformat_end = _rdtsc(); + reformat_cycles = reformat_end - reformat_start; +} +#endif + +if (use_fused_implementation) { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16_amx.tpl.c" +} else { +#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16_amx.tpl.c" +} + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif + +#undef MATRIX_CVT_BF16_FP32_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..f2d535f4142b7fe6a5f665bfc520697a47a0d4f0 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c @@ -0,0 +1,254 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +/* First perform the W*x part of the output */ +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_const_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wo, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + } + } +} + +/* Compute the R*h part of the output */ +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..74103dc7eede54c3b510ba8a998b1518ee5af516 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c @@ -0,0 +1,331 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const src = _src; \ + libxsmm_bfloat16 *const dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&src[(__j*ld)+__i])); \ + _mm512_storeu_si512(&dst[(__j*ld)+__i], packed_result); \ + } \ + } \ +} while (0) + +/* First perform the W*x part of the output */ +blocks = CB_BLOCKS; +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + + +/* Compute the R*h part of the output */ +blocks = KB_BLOCKS; +/* Peel off the t=0 iteration to hoist the innermost if conditions */ +j = 0; +/* let's run the cell in blocks for good locality */ +/* Block reduction loop if requested */ +for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, i, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, ci, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, f, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, o, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = 1; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..93cfb6d34f1efd739e5d3fc28f22d475512a27a1 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16_amx.tpl.c @@ -0,0 +1,331 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const __src = _src; \ + libxsmm_bfloat16 *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i __packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + __packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i])); \ + _mm512_storeu_si512((libxsmm_bfloat16*)&__dst[(__j*ld)+__i], (__m512i) __packed_result); \ + } \ + } \ +} while (0) + +/* First perform the W*x part of the output */ +blocks = CB_BLOCKS; +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + + +/* Compute the R*h part of the output */ +blocks = KB_BLOCKS; +/* Peel off the t=0 iteration to hoist the innermost if conditions */ +j = 0; +/* let's run the cell in blocks for good locality */ +/* Block reduction loop if requested */ +for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, i, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, ci, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, f, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, o, 0, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = 1; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..d4894a4e507b4824c855365b5e86f2f09f07a205 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c @@ -0,0 +1,237 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +/* All data is in column-major format */ +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_const_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wo, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + if (0 == j) { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K); + } + } else { + for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) { + A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K); + } + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..49e8e63a0cb67642396795a9bbc552e30dc74398 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c @@ -0,0 +1,374 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ + +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const src = _src; \ + libxsmm_bfloat16 *const dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&src[(__j*ld)+__i])); \ + _mm512_storeu_si512(&dst[(__j*ld)+__i], packed_result); \ + } \ + } \ +} while (0) + +blocksa = CB_BLOCKS; +blocksb = KB_BLOCKS; + +/* All data is in column-major format */ +/* Peel off the t=0 iteration to hoist the innermost if conditions */ +j = 0; +for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, i, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, ci, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, f, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, o, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = 1; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K); + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..01a9d4afdb7f3562be4aaa7968a3f363d285ab08 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16_amx.tpl.c @@ -0,0 +1,374 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *const __src = _src; \ + libxsmm_bfloat16 *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i __packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + __packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i])); \ + _mm512_storeu_si512((libxsmm_bfloat16*)&__dst[(__j*ld)+__i], (__m512i) __packed_result); \ + } \ + } \ +} while (0) + +blocksa = CB_BLOCKS; +blocksb = KB_BLOCKS; + +/* All data is in column-major format */ +/* Peel off the t=0 iteration to hoist the innermost if conditions */ +j = 0; +for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, i, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, ci, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, f, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(2, hp, in, CB*KB_BLOCKS*bk, K), + &LIBXSMM_VLA_ACCESS(3, o, 0, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) ; + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = 1; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik % (N/bn))*bn; + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, x, j, in, CB*CB_BLOCKS*bc, N, C), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocksa); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(3, h_out, j-1, in, CB*KB_BLOCKS*bk, N, K), + &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocksb); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + + if (CB == BF-1) { +#ifdef PROFILE + if (ltid == 0) { + eltwise_start = _rdtsc(); + } +#endif + cps_ptr = &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K) ; + /* Compute i, ci, f, o, cs, co and h */ +#if defined(LIBXSMM_RNN_CELL_AVX512) + if (bk % 16 == 0 && bc % 16 == 0) { +#include "libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c" + } else { + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); + } +#else + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) ); + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) ); + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) ); + libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) ); +#endif + /* Downconvert computed results to bf16 output buffers */ + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci_out, j, in, ik, N, K)); + NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(bk, bn, K, &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co_out, j, in, ik, N, K)); + +#ifdef PROFILE + if (ltid == 0) { + eltwise_end = _rdtsc(); + eltwise_cycles += eltwise_end-eltwise_start; + } +#endif + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..9ac18aff15a45a583f591a5aefafb092e0b32e37 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c @@ -0,0 +1,226 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Kunal Banerjee (Intel Corp.) +******************************************************************************/ +#if 0 +#define PROFILE +#endif + +#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \ +do { \ + libxsmm_bfloat16 *__src = _src; \ + float *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__dst[(__j*ld)+__i], _mm512_loadcvt_bf16_fp32(&__src[(__j*ld)+__i])); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \ +do { \ + libxsmm_bfloat16 *__colv = _colv; \ + float *__srcdst = _srcdst; \ + libxsmm_blasint __i, __j; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__srcdst[(__j*ld)+__i], _mm512_loadcvt_bf16_fp32(&__colv[__i])); \ + } \ + } \ +} while (0) + +#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \ +do { \ + libxsmm_bfloat16 *__colv = _colv; \ + float *__srcdst = _srcdst; \ + libxsmm_blasint __i, __j; \ + __m512 __vbias = _mm512_set1_ps(const_bias); \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=16 ) { \ + _mm512_store_ps((float*)&__srcdst[(__j*ld)+__i], _mm512_add_ps(__vbias, _mm512_loadcvt_bf16_fp32(&__colv[__i]))); \ + } \ + } \ +} while (0) + +/* helper variables */ +libxsmm_blasint j, ik, ikb, /*in,*/ inb, /*ic, icb,*/ inik, BF, CB, CB_BLOCKS, KB_BLOCKS; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +const libxsmm_blasint cBlocks = C/bc; +const libxsmm_blasint kBlocks = K/bk; +const libxsmm_blasint nBlocks = N/bn; +const int lpb = 2; +const int bc_lp = bc/lpb; +const int bk_lp = bk/lpb; +unsigned long long blocks/*, blocksa, blocksb*/; + +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *csp = (element_input_type* )handle->csp->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *w = (element_filter_type*)handle->w->data; +element_filter_type *r = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; + +/* These buffers are scratch for fp32 output of gemms (intermmediate results) */ +float *cst = (float*)handle->cst_scratch; +/*float *ht = (float*)handle->ht_scratch;*/ +float *it = (float*)handle->it_scratch; +float *ft = (float*)handle->ft_scratch; +float *ot = (float*)handle->ot_scratch; +float *cit = (float*)handle->cit_scratch; +/*float *cot = (float*)handle->cot_scratch;*/ +/* This has to be also upconverted since it is used in the elementwise functions */ +float *csp_f32 = (float*)handle->csp_scratch; +/* These are the output bf16 data */ +element_output_type *cst_bf16 = (element_output_type*)handle->cst->data; +element_output_type *ht_bf16 = (element_output_type*)handle->ht->data; +element_output_type *it_bf16 = (element_output_type*)handle->it->data; +element_output_type *ft_bf16 = (element_output_type*)handle->ft->data; +element_output_type *ot_bf16 = (element_output_type*)handle->ot->data; +element_output_type *cit_bf16 = (element_output_type*)handle->cit->data; +element_output_type *cot_bf16 = (element_output_type*)handle->cot->data; + +element_filter_type *wiD = &(w[0]); +element_filter_type *wcD = &(w[C*K]); +element_filter_type *wfD = &(w[2*C*K]); +element_filter_type *woD = &(w[3*C*K]); +element_filter_type *riD = &(r[0]); +element_filter_type *rcD = &(r[K*K]); +element_filter_type *rfD = &(r[2*K*K]); +element_filter_type *roD = &(r[3*K*K]); +element_output_type *bi = &(b[0]); +element_output_type *bd = &(b[K]); +element_output_type *bf = &(b[2*K]); +element_output_type *bo = &(b[3*K]); +LIBXSMM_VLA_DECL(4, float, cp, csp_f32, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_input_type, cp_bf16, csp, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_input_type, x, xt, nBlocks, cBlocks, bn,bc); +LIBXSMM_VLA_DECL(4, element_input_type, hp, hpD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); +LIBXSMM_VLA_DECL(5, float, cs, cst, nBlocks, kBlocks, bn, bk); +/*LIBXSMM_VLA_DECL(5, float, h, ht, nBlocks, kBlocks, bn, bk);*/ +LIBXSMM_VLA_DECL(5, float, i, it, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, float, f, ft, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, float, o, ot, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, float, ci, cit, nBlocks, kBlocks, bn, bk); +/*LIBXSMM_VLA_DECL(5, float, co, cot, nBlocks, kBlocks, bn, bk);*/ +LIBXSMM_VLA_DECL(5, element_output_type, cs_out, cst_bf16, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, h_out, ht_bf16, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, i_out, it_bf16, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, f_out, ft_bf16, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, o_out, ot_bf16, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, ci_out, cit_bf16, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, co_out, cot_bf16, nBlocks, kBlocks, bn, bk); +/* define batch-reduce gemm kernels */ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL );*/ +const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL );*/ +const libxsmm_bsmmfunction_reducebatch_addr tile_config_kernel = handle->fwd_tileconfig; /*libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL );*/ + +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +#ifdef PROFILE +__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; +float total_time = 0.0; +#endif + +/* Hoist tileconfig if possible */ +if ((bk % 32 == 0) && (bc % 32 == 0) && (bn % 32 == 0)) { + tile_config_kernel(NULL, NULL, NULL, NULL); +} + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { + BF = 8; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} +if (C > 2048 || K > 2048) { + BF = 16; + while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { + BF--; + } +} + +if (C == 2048 && K == 1024) { + BF = 2; +} + +/* Overwrite the blocking factor based on the value passed onto the descriptor */ +BF = handle->fwd_block; + +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; + +#ifdef PROFILE +if (ltid == 0) reformat_start = _rdtsc(); +#endif + +/* Upconvert the cp input to fp32 that is used for elementwise stuff */ +for (inik = thr_begin; inik < thr_end; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + MATRIX_CVT_BF16_FP32_LD( bk, bn, bk, &LIBXSMM_VLA_ACCESS(4, cp_bf16, inb, ikb, 0, 0, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(4, cp, inb, ikb, 0, 0, kBlocks, bn, bk)); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); +#ifdef PROFILE +if (ltid == 0) { + reformat_end = _rdtsc(); + reformat_cycles = reformat_end - reformat_start; +} +#endif + +#include "libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_diffused_bf16_amx.tpl.c" + +handle->tilerelease_kernel(NULL, NULL, NULL); + +#ifdef PROFILE +if (ltid == 0) { + printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); + total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; + printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); + printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); + printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); +} +#undef PROFILE +#endif + +#undef MATRIX_CVT_BF16_FP32_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD +#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_diffused_bf16_amx.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_diffused_bf16_amx.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..dc2d0fe67a90f0b32bc56c654b2d4dd462dff024 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_diffused_bf16_amx.tpl.c @@ -0,0 +1,409 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.) +******************************************************************************/ +#define NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD(m, n, ld, _src, _dst) \ +do { \ + float *__src = _src; \ + libxsmm_bfloat16 *__dst = _dst; \ + libxsmm_blasint __i, __j; \ + __m512i __packed_result; \ + for ( __j = 0; __j < n; ++__j ) { \ + for ( __i = 0; __i < m; __i+=32 ) { \ + __packed_result = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i+16]), LIBXSMM_INTRINSICS_MM512_LOAD_PS((float*)&__src[(__j*ld)+__i])); \ + _mm512_storeu_si512((libxsmm_bfloat16*)&__dst[(__j*ld)+__i], (__m512i) __packed_result); \ + } \ + } \ +} while (0) + +/* First perform the W*x part of the output */ +blocks = CB_BLOCKS; +for (j = 0; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + ik = ikb*bk; + /* initialize i with bi */ +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &bi[ik] ); + /* i += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wi, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, x, j, inb, CB*CB_BLOCKS, 0, 0, nBlocks, cBlocks, bn, bc), + &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); + + /* initialize ci with bd */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &bd[ik] ); + /* ci += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wc, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, x, j, inb, CB*CB_BLOCKS, 0, 0, nBlocks, cBlocks, bn, bc), + &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); + + /* initialize f with (bf + forget_bias) */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &bf[ik], handle->forget_bias ); + /* f += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wf, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, x, j, inb, CB*CB_BLOCKS, 0, 0, nBlocks, cBlocks, bn, bc), + &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); + + /* initialize o with bo */ + if (CB == 0) MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &bo[ik] ); + /* o += W.x */ + batchreduce_kernela(&LIBXSMM_VLA_ACCESS(5, wo, ikb, CB*CB_BLOCKS, 0, 0, 0, cBlocks, bc_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, x, j, inb, CB*CB_BLOCKS, 0, 0, nBlocks, cBlocks, bn, bc), + &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles += gemm_end-gemm_start; + } +#endif + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +/* Compute the R*h part of the output */ +blocks = KB_BLOCKS; +/* Peel off the t=0 iteration to hoist the innermost if conditions */ +j = 0; +/* let's run the cell in blocks for good locality */ +/* Block reduction loop if requested */ +for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hp, inb, CB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the i computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _i = &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst = &LIBXSMM_VLA_ACCESS(5, i_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vi0, _vi1; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vi0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k] ); + _vi0 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vi0, _halves ) ), _halves, _halves); + _mm512_store_ps( &_i[(_j*bk)+_k], _vi0 ); + _vi1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k+16] ); + _vi1 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vi1, _halves ) ), _halves, _halves); + _mm512_store_ps( &_i[(_j*bk)+_k+16], _vi1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vi1, _vi0)); + } + } + } +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hp, inb, CB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the ci computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _ci = &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst = &LIBXSMM_VLA_ACCESS(5, ci_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vci0, _vci1; + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vci0 = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k] )); + _mm512_store_ps( &_ci[(_j*bk)+_k], _vci0 ); + _vci1 = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k+16] )); + _mm512_store_ps( &_ci[(_j*bk)+_k+16], _vci1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vci1, _vci0)); + } + } + } +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hp, inb, CB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the f computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _f = &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst = &LIBXSMM_VLA_ACCESS(5, f_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vf0, _vf1; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vf0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k] ); + _vf0 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vf0, _halves ) ), _halves, _halves); + _mm512_store_ps( &_f[(_j*bk)+_k], _vf0 ); + _vf1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k+16] ); + _vf1 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vf1, _halves ) ), _halves, _halves); + _mm512_store_ps( &_f[(_j*bk)+_k+16], _vf1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vf1, _vf0)); + } + } + } +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(4, hp, inb, CB*KB_BLOCKS, 0, 0, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the o computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _o = &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _i = &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _f = &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _ci = &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _cps = &LIBXSMM_VLA_ACCESS(4, cp, inb, ikb, 0, 0, kBlocks, bn, bk); + float* _cs = &LIBXSMM_VLA_ACCESS(5, cs, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_o = &LIBXSMM_VLA_ACCESS(5, o_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_cs = &LIBXSMM_VLA_ACCESS(5, cs_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_h = &LIBXSMM_VLA_ACCESS(5, h_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_co = &LIBXSMM_VLA_ACCESS(5, co_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vf, _vcs, _vi, _vci, _vco, _vo, _vh, _vf1, _vcs1, _vi1, _vci1, _vco1, _vo1, _vh1; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*bk)+_k] ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k] ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k] ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k] ); + _vcs = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*bk)+_k] ); + _vo = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vo, _halves ) ), _halves, _halves); + _vcs = _mm512_mul_ps( _vf, _vcs ); + _vcs = _mm512_fmadd_ps( _vi, _vci, _vcs ); + _mm512_store_ps( &_cs[(_j*bk)+_k], _vcs ); + _vco = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vcs ); + _vh = _mm512_mul_ps( _vo, _vco ); + _vo1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*bk)+_k+16] ); + _vi1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k+16] ); + _vci1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k+16] ); + _vf1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k+16] ); + _vcs1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*bk)+_k+16] ); + _vo1 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vo1, _halves ) ), _halves, _halves); + _vcs1 = _mm512_mul_ps( _vf1, _vcs1 ); + _vcs1 = _mm512_fmadd_ps( _vi1, _vci1, _vcs1 ); + _mm512_store_ps( &_cs[(_j*bk)+_k+16], _vcs1 ); + _vco1 = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vcs1 ); + _vh1 = _mm512_mul_ps( _vo1, _vco1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_o[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vo1, _vo)); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_cs[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vcs1, _vcs)); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_h[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vh1, _vh)); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_co[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vco1, _vco)); + } + } + } + } +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (j = 1; j < t; ++j) { + /* let's run the cell in blocks for good locality */ + /* Block reduction loop if requested */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + inb = inik % (N/bn); + ikb = inik / (N/bn); + ik = ikb*bk; +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* i += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ri, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, h_out, j-1, inb, CB*KB_BLOCKS, 0, 0, nBlocks, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the i computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _i = &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst = &LIBXSMM_VLA_ACCESS(5, i_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vi0, _vi1; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vi0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k] ); + _vi0 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vi0, _halves ) ), _halves, _halves); + _mm512_store_ps( &_i[(_j*bk)+_k], _vi0 ); + _vi1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k+16] ); + _vi1 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vi1, _halves ) ), _halves, _halves); + _mm512_store_ps( &_i[(_j*bk)+_k+16], _vi1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vi1, _vi0)); + } + } + } +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* ci += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rc, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, h_out, j-1, inb, CB*KB_BLOCKS, 0, 0, nBlocks, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the ci computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _ci = &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst = &LIBXSMM_VLA_ACCESS(5, ci_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vci0, _vci1; + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vci0 = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k] )); + _mm512_store_ps( &_ci[(_j*bk)+_k], _vci0 ); + _vci1 = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k+16] )); + _mm512_store_ps( &_ci[(_j*bk)+_k+16], _vci1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vci1, _vci0)); + } + } + } +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* f += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, rf, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, h_out, j-1, inb, CB*KB_BLOCKS, 0, 0, nBlocks, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the f computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _f = &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst = &LIBXSMM_VLA_ACCESS(5, f_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vf0, _vf1; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vf0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k] ); + _vf0 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vf0, _halves ) ), _halves, _halves); + _mm512_store_ps( &_f[(_j*bk)+_k], _vf0 ); + _vf1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k+16] ); + _vf1 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vf1, _halves ) ), _halves, _halves); + _mm512_store_ps( &_f[(_j*bk)+_k+16], _vf1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vf1, _vf0)); + } + } + } +#ifdef PROFILE + if (ltid == 0) gemm_start = _rdtsc(); +#endif + /* o += R.h */ + batchreduce_kernelb(&LIBXSMM_VLA_ACCESS(5, ro, ikb, CB*KB_BLOCKS, 0, 0, 0, kBlocks, bk_lp, bk, lpb), + &LIBXSMM_VLA_ACCESS(5, h_out, j-1, inb, CB*KB_BLOCKS, 0, 0, nBlocks, kBlocks, bn, bk), + &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); +#ifdef PROFILE + if (ltid == 0) { + gemm_end = _rdtsc(); + gemm_cycles2 += gemm_end-gemm_start; + } +#endif + /* Eltwise ops and downcovert for the o computed block */ + if (CB == BF-1) { + libxsmm_blasint _k, _j; + float* _o = &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _i = &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _f = &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _ci = &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _cps = &LIBXSMM_VLA_ACCESS(5, cs, j-1, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + float* _cs = &LIBXSMM_VLA_ACCESS(5, cs, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_o = &LIBXSMM_VLA_ACCESS(5, o_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_cs = &LIBXSMM_VLA_ACCESS(5, cs_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_h = &LIBXSMM_VLA_ACCESS(5, h_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + libxsmm_bfloat16 *dst_co = &LIBXSMM_VLA_ACCESS(5, co_out, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + __m512 _vf, _vcs, _vi, _vci, _vco, _vo, _vh, _vf1, _vcs1, _vi1, _vci1, _vco1, _vo1, _vh1; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 32 ) { + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*bk)+_k] ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k] ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k] ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k] ); + _vcs = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*bk)+_k] ); + _vo = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vo, _halves ) ), _halves, _halves); + _vcs = _mm512_mul_ps( _vf, _vcs ); + _vcs = _mm512_fmadd_ps( _vi, _vci, _vcs ); + _mm512_store_ps( &_cs[(_j*bk)+_k], _vcs ); + _vco = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vcs ); + _vh = _mm512_mul_ps( _vo, _vco ); + _vo1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*bk)+_k+16] ); + _vi1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*bk)+_k+16] ); + _vci1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*bk)+_k+16] ); + _vf1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*bk)+_k+16] ); + _vcs1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*bk)+_k+16] ); + _vo1 = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vo1, _halves ) ), _halves, _halves); + _vcs1 = _mm512_mul_ps( _vf1, _vcs1 ); + _vcs1 = _mm512_fmadd_ps( _vi1, _vci1, _vcs1 ); + _mm512_store_ps( &_cs[(_j*bk)+_k+16], _vcs1 ); + _vco1 = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vcs1 ); + _vh1 = _mm512_mul_ps( _vo1, _vco1 ); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_o[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vo1, _vo)); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_cs[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vcs1, _vcs)); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_h[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vh1, _vh)); + _mm512_storeu_si512((libxsmm_bfloat16*)&dst_co[(_j*bk)+_k], (__m512i) LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH(_vco1, _vco)); + } + } + } + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + +#undef NATIVE_MATRIX_RNE_CVT_FP32_BFP16_LD + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..3dba8bbd569a746a005ae484a355a5a059009b58 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c @@ -0,0 +1,357 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint i, ik, in, ic, jk, jb/*jn shadows global variable*/, jc, ek, en, ec; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *wD = (element_filter_type*)handle->w->data; +element_filter_type *rD = (element_filter_type*)handle->r->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_filter_type *dwD = (element_filter_type*)handle->dw->data; +element_filter_type *drD = (element_filter_type*)handle->dr->data; +element_output_type *db = (element_output_type*)handle->db->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *deltat = (element_output_type*)handle->scratch_deltat; +element_input_type *scratch_xT = (element_input_type*)handle->scratch_xT; +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +/* multidimensional arrays */ +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, w, wD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, r, rD, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C); +LIBXSMM_VLA_DECL(2, element_filter_type, dw, dwD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, dr, drD, K); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, delta, deltat, N, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +LIBXSMM_VLA_DECL(2, element_filter_type, wT, scratch_wT, C); +LIBXSMM_VLA_DECL(2, element_filter_type, rT, scratch_rT, K); +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +#if defined(LIBXSMM_DNN_RNN_RELU_BWDUPD) || defined(LIBXSMM_DNN_RNN_SIGMOID_BWDUPD) || defined(LIBXSMM_DNN_RNN_TANH_BWDUPD) +element_output_type *zt = (element_output_type*)handle->internal_z; +LIBXSMM_VLA_DECL(3, element_output_type, z, zt, N, K); +#endif +/* define gemm kernels */ +libxsmm_smmfunction gemmkernela = libxsmm_smmdispatch( bc, bn, bk, &C, &K, &C, NULL, NULL, NULL, NULL ); +libxsmm_smmfunction gemmkernelb = libxsmm_smmdispatch( bk, bk, bn, &K, &N, &K, NULL, NULL, NULL, NULL ); +libxsmm_smmfunction gemmkernelc = libxsmm_smmdispatch( bk, bc, bn, &K, &N, &K, NULL, NULL, NULL, NULL ); +libxsmm_smmfunction gemmkerneld = libxsmm_smmdispatch( bk, bn, bk, &K, &K, &K, NULL, NULL, NULL, NULL ); + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +/* number of tasks that could be run in parallel for K blocks*/ +/* compute chunk size */ +const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K; +const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K; + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* initialization is done at the beginning */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads); +} +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + libxsmm_internal_matrix_zero(C*K, dwD, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K*K, drD, start_thread, tid, handle->desc.threads); + libxsmm_internal_matrix_zero(K, db, start_thread, tid, handle->desc.threads); +} + +/* transpose W */ +for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ik = (ikic / (C/bc))*bk; + ic = (ikic % (C/bc))*bc; + + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + ek = ik + jk; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, wT, ek, ec, C) = LIBXSMM_VLA_ACCESS(2, w, ec, ek, K); + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk))*bk; + ic = (ikic % (K/bk))*bk; + + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + ek = ik + jk; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, rT, ek, ec, K) = LIBXSMM_VLA_ACCESS(2, r, ec, ek, K); + } + } +} + +/* transpose xt for current timestep */ +for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + ic = (icin / (N/bn))*bc; + in = (icin % (N/bn))*bn; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, t-1, en, ec, N, C); + } + } +} + +/* transpose ht for current timestep */ +for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + ik = (ikin / (N/bn))*bk; + in = (ikin % (N/bn))*bn; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, t-2, en, ek, N, K); + } + } +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +/* The following code is for time step t-1 */ +for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik / (K/bk))*bn; + ik = (inik % (K/bk))*bk; + +#if defined(LIBXSMM_DNN_RNN_RELU_BWDUPD) + libxsmm_internal_matrix_relu_inverse_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_BWDUPD) + libxsmm_internal_matrix_sigmoid_inverse_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_BWDUPD) + libxsmm_internal_matrix_tanh_inverse_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +#endif + + libxsmm_internal_matrix_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), + &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* gemm kernel bwd_d */ + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic / (C/bc))*bn; + ic = (inic % (C/bc))*bc; + + for (ik = 0; ik < K; ik += bk) { + gemmkernela( &LIBXSMM_VLA_ACCESS(2, wT, ik, ic, C), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, dx, t-1, in, ic, N, C) ); + } + } +} +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* gradient bias */ + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + db[ik] += LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K); + } + } + + /* dr = delta * h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ic = (ikic / (K/bk))*bk; + ik = (ikic % (K/bk))*bk; + + for (in = 0; in < N; in += bn) { + gemmkernelb( &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N), &LIBXSMM_VLA_ACCESS(2, dr, ic, ik, K) ); + } + } + + /* dw = delta * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk))*bc; + ik = (ikic % (K/bk))*bk; + + for (in = 0; in < N; in += bn ) { + gemmkernelc( &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N), &LIBXSMM_VLA_ACCESS(2, dw, ic, ik, K) ); + } + } +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +for (i = t-2; i >= 0; --i) { + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + ic = (icin / (N/bn))*bc; + in = (icin % (N/bn))*bn; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, i, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (0 == i) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + ik = (ikin / (N/bn))*bk; + in = (ikin % (N/bn))*bn; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + ik = (ikin / (N/bn))*bk; + in = (ikin % (N/bn))*bn; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, i-1, en, ek, N, K); + } + } + } + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik / (K/bk))*bn; + ik = (inik % (K/bk))*bk; + + /* delta = dh */ + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); + + /* delta += R^T * delta+1 */ + for (ic = 0; ic < K; ic += bk) { + gemmkerneld( &LIBXSMM_VLA_ACCESS(2, rT, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, delta, i+1, in, ic, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); + } + + /* run inverse non-linear op */ +#if defined(LIBXSMM_DNN_RNN_RELU_BWDUPD) + libxsmm_internal_matrix_relu_inverse_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_BWDUPD) + libxsmm_internal_matrix_sigmoid_inverse_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_BWDUPD) + libxsmm_internal_matrix_tanh_inverse_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); +#endif + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dx = W^T * delta */ + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic / (C/bc))*bn; + ic = (inic % (C/bc))*bc; + + for (ik = 0; ik < K; ik += bk) { + gemmkernela( &LIBXSMM_VLA_ACCESS(2, wT, ik, ic, C), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, dx, i, in, ic, N, C) ); + } + } + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* gradient bias */ + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + db[ik] += LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K); + } + } + + /* dr = delta * h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ic = (ikic / (K/bk))*bk; + ik = (ikic % (K/bk))*bk; + + for (in = 0; in < N; in += bn) { + gemmkernelb( &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N), &LIBXSMM_VLA_ACCESS(2, dr, ic, ik, K) ); + } + } + + /* dw = delta * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ic = (ikic / (K/bk))*bc; + ik = (ikic % (K/bk))*bk; + + for (in = 0; in < N; in += bn ) { + gemmkernelc( &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N), &LIBXSMM_VLA_ACCESS(2, dw, ic, ik, K) ); + } + } + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..2b18e0d7a2f2cd10ae6495004938d609aff0ba12 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c @@ -0,0 +1,425 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint i, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB; +/* tensor dimensions */ +libxsmm_blasint K = handle->desc.K; +libxsmm_blasint N = handle->desc.N; +libxsmm_blasint C = handle->desc.C; +libxsmm_blasint t = handle->T; +libxsmm_blasint bk = handle->bk; +libxsmm_blasint bn = handle->bn; +libxsmm_blasint bc = handle->bc; +/* tensor raw pointers */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD = (element_input_type* )handle->hp->data; +element_filter_type *wtD = (element_filter_type*)handle->wt->data; +element_filter_type *rtD = (element_filter_type*)handle->rt->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_input_type *dxt = (element_input_type*)handle->dxt->data; +element_filter_type *dwD = (element_filter_type*)handle->dw->data; +element_filter_type *drD = (element_filter_type*)handle->dr->data; +element_output_type *db = (element_output_type*)handle->db->data; +element_output_type *dht = (element_output_type*)handle->dht->data; +element_output_type *deltat = (element_output_type*)handle->scratch_deltat; +element_input_type *scratch_xT = (element_input_type*)handle->scratch_xT; +#if 0 +element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT; +element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT; +#endif +element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT; +/* Auxiliary variables for bact-reduce calls */ +libxsmm_blasint nBlocks = N/bn; +libxsmm_blasint cBlocks = C/bc; +libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; +const float beta = 0.0; +/* multidimensional arrays */ +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, wT, wtD, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, rT, rtD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C); +LIBXSMM_VLA_DECL(4, element_filter_type, dw, dwD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, dr, drD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, delta, deltat, N, K); +LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N); +#if 0 +LIBXSMM_VLA_DECL(4, element_filter_type, wT, scratch_wT, kBlocks, bk, bc); +LIBXSMM_VLA_DECL(4, element_filter_type, rT, scratch_rT, kBlocks, bk, bk); +#endif +LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N); +#if defined(LIBXSMM_DNN_RNN_RELU_BWDUPD) || defined(LIBXSMM_DNN_RNN_SIGMOID_BWDUPD) || defined(LIBXSMM_DNN_RNN_TANH_BWDUPD) +element_output_type *zt = (element_output_type*)handle->internal_z; +LIBXSMM_VLA_DECL(3, element_output_type, z, zt, N, K); +#endif +/* define batch-reduce gemm kernels */ +/*const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelaz = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, &beta, NULL, NULL);*/ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelbz = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, &beta, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelcz = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, &beta, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL); + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; + +/* number of tasks that could be run in parallel for N and K blocks*/ +const libxsmm_blasint work_nk = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk; +const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk; + +/* number of tasks that could be run in parallel for N and C blocks*/ +const libxsmm_blasint work_nc = (N/bn) * (C/bc); +/* compute chunk size */ +const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc; +const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc; + +/* number of tasks that could be run in parallel for C and K blocks*/ +const libxsmm_blasint work_ck = (C/bc) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; +const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck; + +/* number of tasks that could be run in parallel for K and K blocks*/ +const libxsmm_blasint work_kk = (K/bk) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; +const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; + +#if defined(LIBXSMM_RNN_CELL_AVX512) +int k_tasks = K/16; +int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K; +const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K; +__m512 db_sum; +#else +/* number of tasks that could be run in parallel for K blocks*/ +/* compute chunk size */ +const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K; +const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K; + +#endif + +libxsmm_blasint ikic, inic, inik, icin, ikin; + +/* Auxiliary arrays for batch-reduce gemm calls */ +const element_filter_type *A_array[1024]; +const element_output_type *B_array[1024]; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (C >= 512 && K >= 512 && C%2 == 0 && K%2 == 0) { + BF = 2; +} +if (C >= 2048 && K >= 2048 && C%8 == 0 && K%8 == 0) { + BF = 8; +} +KB_BLOCKS = kBlocks/BF; + +#if 0 +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* transpose W */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + ik = (ikic / (C/bc)); + ic = (ikic % (C/bc)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bc; ++jc) { + LIBXSMM_VLA_ACCESS(4, wT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(4, w, ik, ic, jc, jk, cBlocks, bc, bk); + } + } + } +} + +/* transpose R */ +for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + ik = (ikic / (K/bk)); + ic = (ikic % (K/bk)); + for (jk = 0; jk < bk; ++jk) { + for (jc = 0; jc < bk; ++jc) { + LIBXSMM_VLA_ACCESS(4, rT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(4, r, ik, ic, jc, jk, kBlocks, bk, bk); + } + } +} +#endif + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + ic = (icin / (N/bn))*bc; + in = (icin % (N/bn))*bn; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, t-1, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + ik = (ikin / (N/bn))*bk; + in = (ikin % (N/bn))*bn; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, t-2, en, ek, N, K); + } + } + } +} + +/* The following code is for time step t-1 */ +for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik / (K/bk))*bn; + ik = (inik % (K/bk))*bk; + +#if defined(LIBXSMM_DNN_RNN_RELU_BWDUPD) + libxsmm_internal_matrix_relu_inverse_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_BWDUPD) + libxsmm_internal_matrix_sigmoid_inverse_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_BWDUPD) + libxsmm_internal_matrix_tanh_inverse_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +#endif + libxsmm_internal_matrix_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, t-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K) ); +} + +libxsmm_barrier_wait(handle->barrier, (int)ltid); + +if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* gemm kernel bwd_d */ + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic / (C/bc))*bn; + icb = (inic % (C/bc)); + ic = icb * bc; + /* Prepare arguments for batch-reduce call */ + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik+=bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik + KB*KB_BLOCKS*bk, N, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, t-1, in, ic, N, C), &blocks); + } + } +} + +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dr = delta * h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelbz(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dr, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = delta * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(3, delta, t-1, in, ik, N, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelcz(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dw, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } +} + +for (i = t-2; i >= 0; --i) { + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin_nk; inik < thr_end_nk; ++inik ) { + in = (inik / (K/bk))*bn; + ikb = (inik % (K/bk)); + ik = ikb*bk; + /* delta = dh */ + libxsmm_internal_matrix_copy_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, dh, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); + + /* delta += R^T * delta+1 */ + for (ic = 0; ic < kBlocks; ic++) { + A_array[ic] = &LIBXSMM_VLA_ACCESS(4, rT, ikb, ic, 0, 0, kBlocks, bk, bk); + B_array[ic] = &LIBXSMM_VLA_ACCESS(3, delta, i+1, in, ic*bk, N, K); + } + /* Reduce batch gemm call */ + blocks = kBlocks; + batchreduce_kerneld(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) , &blocks); + + /* run inverse non-linear op */ +#if defined(LIBXSMM_DNN_RNN_RELU_BWDUPD) + libxsmm_internal_matrix_relu_inverse_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_BWDUPD) + libxsmm_internal_matrix_sigmoid_inverse_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_BWDUPD) + libxsmm_internal_matrix_tanh_inverse_inplace_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K) ); +#endif + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* transpose xt for current timestep */ + for (icin = thr_begin_nc; icin < thr_end_nc; ++icin ) { + ic = (icin / (N/bn))*bc; + in = (icin % (N/bn))*bn; + + for (jc = 0; jc < bc; ++jc) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ec = ic + jc; + LIBXSMM_VLA_ACCESS(2, xT, ec, en, N) = LIBXSMM_VLA_ACCESS(3, x, i, en, ec, N, C); + } + } + } + + /* transpose ht for current timestep */ + if (0 == i) { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + ik = (ikin / (N/bn))*bk; + in = (ikin % (N/bn))*bn; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(2, hp, en, ek, K); + } + } + } + } else { + for (ikin = thr_begin_nk; ikin < thr_end_nk; ++ikin ) { + ik = (ikin / (N/bn))*bk; + in = (ikin % (N/bn))*bn; + + for (jk = 0; jk < bk; ++jk) { + for (jb = 0; jb < bn; ++jb) { + en = in + jb; + ek = ik + jk; + LIBXSMM_VLA_ACCESS(2, hT, ek, en, N) = LIBXSMM_VLA_ACCESS(3, h, i-1, en, ek, N, K); + } + } + } + } + } + + if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dx = W^T * delta */ + for (KB = 0; KB < BF; KB++) { + for (inic = thr_begin_nc; inic < thr_end_nc; ++inic ) { + in = (inic / (C/bc))*bn; + icb = (inic % (C/bc)); + ic = icb * bc; + /* Prepare arguments for batch-reduce call */ + for (ik = 0, ikb = 0; ikb < KB_BLOCKS; ik+=bk, ikb++) { + A_array[ikb] = &LIBXSMM_VLA_ACCESS(4, wT, icb, ikb + KB*KB_BLOCKS, 0, 0, kBlocks, bk, bc); + B_array[ikb] = &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik + KB*KB_BLOCKS*bk, N, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, dx, i, in, ic, N, C), &blocks); + } + } + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); + + if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { + /* dr = delta * h^T */ + for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bk; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, hT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dr, ikb, icb, 0, 0, kBlocks, bk, bk), &blocks); + } + + /* dw = delta * x^T */ + for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) { + icb = ikic / (K/bk); + ic = icb*bc; + ikb = ikic % (K/bk); + ik = ikb*bk; + + for (in = 0, inb = 0; in < N; in += bn, inb++) { + A_array[inb] = &LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K); + B_array[inb] = &LIBXSMM_VLA_ACCESS(2, xT, ic, in, N); + } + blocks = nBlocks; + batchreduce_kernelc(A_array, B_array, &LIBXSMM_VLA_ACCESS(4, dw, ikb, icb, 0, 0, cBlocks, bc, bk), &blocks); + } + } +} + +/* gradient bias */ +if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) { +#if defined(LIBXSMM_RNN_CELL_AVX512) + for (ik = k_thr_begin; ik < k_thr_end; ik += 16) { + db_sum = _mm512_setzero_ps(); + for (i = 0; i < t; i++) { + for (in = 0; in < N; in++) { + db_sum = _mm512_add_ps(db_sum, LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K))); + } + } + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&db[ik], db_sum); + } +#else + for (i = 0; i < t; i++) { + for (ik = thr_begin_k; ik < thr_end_k; ik++) { + for (in = 0; in < N; in++) { + db[ik] += LIBXSMM_VLA_ACCESS(3, delta, i, in, ik, N, K); + } + } + } +#endif +} +libxsmm_barrier_wait(handle->barrier, (int)ltid); diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..0e77df1e195dae3dc822694271e95b2cab440f48 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c @@ -0,0 +1,92 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint i, ik, in, ic, inik; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD= (element_input_type* )handle->hp->data; +element_filter_type *wD = (element_filter_type*)handle->w->data; +element_filter_type *rD = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *zt = (element_output_type*)handle->internal_z; +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, w, wD, K); +LIBXSMM_VLA_DECL(2, element_filter_type, r, rD, K); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, z, zt, N, K); +/* define gemm kernels */ +libxsmm_smmfunction gemmkernela = libxsmm_smmdispatch( bk, bn, bc, &K, &C, &K, NULL, NULL, NULL, NULL ); +libxsmm_smmfunction gemmkernelb = libxsmm_smmdispatch( bk, bn, bk, &K, &K, &K, NULL, NULL, NULL, NULL ); +/* parallelize over C-blocks */ +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* All data is in column-major format */ +for (i = 0; i < t; ++i) { + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = (inik / (K/bk))*bn; + ik = (inik % (K/bk))*bk; + + /* z = per_col(b) */ + libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &b[ik] ); + + /* z += W.x */ + for (ic = 0; ic < C; ic += bc) { + /* this is a small matmul */ + gemmkernela( &LIBXSMM_VLA_ACCESS(2, w, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, x, i, in, ic, N, C), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) ); + } + /* z += U.h */ + if (0 == i) { + for (ic = 0; ic < K; ic += bk) { + /* this is a small matmul */ + gemmkernelb( &LIBXSMM_VLA_ACCESS(2, r, ic, ik, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ic, K), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) ); + } + } else { + for (ic = 0; ic < K; ic += bk) { + /* this is a small matmul */ + gemmkernelb( &LIBXSMM_VLA_ACCESS(2, r, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, h, i-1, in, ic, N, K), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) ); + } + } +#if defined(LIBXSMM_DNN_RNN_RELU_FWD) + libxsmm_internal_matrix_relu_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_FWD) + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_FWD) + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) ); +#endif + } + + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..a945819a481f2a3f9ab0c62cb244f3035247fbcf --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c @@ -0,0 +1,136 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint i, ik, in, ic, inik, BF, CB, CB_BLOCKS, KB_BLOCKS; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD= (element_input_type* )handle->hp->data; +element_filter_type *wD = (element_filter_type*)handle->w->data; +element_filter_type *rD = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *zt = (element_output_type*)handle->internal_z; +/*libxsmm_blasint nBlocks = N/bn;*/ +libxsmm_blasint cBlocks = C/bc; +libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; +LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); +LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); +LIBXSMM_VLA_DECL(4, element_filter_type, w, wD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, r, rD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); +LIBXSMM_VLA_DECL(3, element_output_type, z, zt, N, K); +int prefetch_mode = LIBXSMM_GEMM_PREFETCH_NONE/*LIBXSMM_GEMM_PREFETCH_AL1_BL1*/; +/* define gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, &prefetch_mode ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, &prefetch_mode ); + +/* Auxiliary arrays for batch-reduce gemms */ +const element_input_type *A_array[1024]; +const element_input_type *B_array[1024]; +const element_input_type *A_array2[1024]; +const element_input_type *B_array2[1024]; + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* Blocking reduction domain if it is too large */ +BF = 1; +if (C >= 2048 && K >= 2048 && C%2 == 0 && K%2 == 0) { + BF = 2; +} +CB_BLOCKS = cBlocks/BF; +KB_BLOCKS = kBlocks/BF; +assert(CB_BLOCKS <= 1024); +assert(KB_BLOCKS <= 1024); + +/* lazy barrier init */ +libxsmm_barrier_init(handle->barrier, (int)ltid); + +/* All data is in column-major format */ +for (i = 0; i < t; ++i) { + /* let's run the cell in blocks for good locality */ + for (CB = 0; CB < BF; CB++) { + for (inik = thr_begin; inik < thr_end; ++inik ) { + if (C >= 2048 && K >= 2048) { + in = inik % (N/bn); + ik = inik / (N/bn); + } else { + in = inik / (K/bk); + ik = inik % (K/bk); + } + + /* z = per_col(b) */ + if (0 == CB) { + libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &b[ik*bk] ); + } + + /* z += W.x */ + /* Prepare arrays for the call */ + for (ic = 0; ic < CB_BLOCKS; ic++) { + /* this is a small matmul */ + A_array[ic] = &LIBXSMM_VLA_ACCESS(4, w, ik, ic + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk); + B_array[ic] = &LIBXSMM_VLA_ACCESS(3, x, i, in*bn, (ic + CB*CB_BLOCKS)*bc, N, C); + } + /* Reduce batch gemm call */ + blocks = CB_BLOCKS; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &blocks); + + /* z += U.h */ + if (0 == i) { + /* Prepare arrays for the call */ + for (ic = 0; ic < KB_BLOCKS; ic++) { + A_array2[ic] = &LIBXSMM_VLA_ACCESS(4, r, ik, ic + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array2[ic] = &LIBXSMM_VLA_ACCESS(2, hp, in*bn, (ic + CB*KB_BLOCKS)*bk, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array2, B_array2, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &blocks); + } else { + /* Prepare arrays for the call */ + for (ic = 0; ic < KB_BLOCKS; ic++) { + A_array2[ic] = &LIBXSMM_VLA_ACCESS(4, r, ik, ic + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk); + B_array2[ic] = &LIBXSMM_VLA_ACCESS(3, h, i-1, in*bn, (ic + CB*KB_BLOCKS)*bk, N, K); + } + /* Reduce batch gemm call */ + blocks = KB_BLOCKS; + batchreduce_kernelb(A_array2, B_array2, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &blocks); + } +#if defined(LIBXSMM_DNN_RNN_RELU_FWD) + libxsmm_internal_matrix_relu_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in*bn, ik*bk, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_FWD) + libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in*bn, ik*bk, N, K) ); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_FWD) + libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in*bn, ik*bk, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in*bn, ik*bk, N, K) ); +#endif + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..2717adf3c84146cc40b5f28c7bd1e5d9ebeae7c4 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c @@ -0,0 +1,234 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas, Alexander Heinecke, Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +/* helper variables */ +libxsmm_blasint i, ik, in, ic, inik; +/* input sizes */ +const libxsmm_blasint K = handle->desc.K; +const libxsmm_blasint N = handle->desc.N; +const libxsmm_blasint C = handle->desc.C; +const libxsmm_blasint t = handle->T; +const libxsmm_blasint bk = handle->bk; +const libxsmm_blasint bn = handle->bn; +const libxsmm_blasint bc = handle->bc; +/* define tensors */ +element_input_type *xt = (element_input_type* )handle->xt->data; +element_input_type *hpD= (element_input_type* )handle->hp->data; +element_filter_type *wD = (element_filter_type*)handle->w->data; +element_filter_type *rD = (element_filter_type*)handle->r->data; +element_output_type *b = (element_output_type*)handle->b->data; +element_output_type *ht = (element_output_type*)handle->ht->data; +element_output_type *zt = (element_output_type*)handle->internal_z; +libxsmm_blasint nBlocks = N/bn; +libxsmm_blasint cBlocks = C/bc; +libxsmm_blasint kBlocks = K/bk; +unsigned long long blocks; +LIBXSMM_VLA_DECL(5, element_input_type, x, xt, nBlocks, cBlocks, bn, bc); +LIBXSMM_VLA_DECL(4, element_input_type, hp, hpD, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, w, wD, cBlocks, bc, bk); +LIBXSMM_VLA_DECL(4, element_filter_type, r, rD, kBlocks, bk, bk); +LIBXSMM_VLA_DECL(5, element_output_type, h, ht, nBlocks, kBlocks, bn, bk); +LIBXSMM_VLA_DECL(5, element_output_type, z, zt, nBlocks, kBlocks, bn, bk); +int prefetch_mode = LIBXSMM_GEMM_PREFETCH_NONE/*LIBXSMM_GEMM_PREFETCH_AL1_BL1*/; +/* define gemm kernels */ +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &bc, &bk, NULL, NULL, NULL, &prefetch_mode ); +const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &bk, &bk, NULL, NULL, NULL, &prefetch_mode ); + +/* computing first logical thread */ +const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; +/* number of tasks that could be run in parallel */ +const libxsmm_blasint work = (N/bn) * (K/bk); +/* compute chunk size */ +const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; +libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; + +/* The snippet below does a 2D domain decomposition of output IF the number of threads and the number of work items are compatible */ +/* TODO: For now 2D decomposition targets single socket SKX */ +int row_teams = 7; +int column_teams = 4; +libxsmm_blasint my_col_id = ltid % column_teams; +libxsmm_blasint my_row_id = ltid / column_teams; +int in_tasks = (int)(N/bn); +int ik_tasks = (int)(K/bk); +int in_tasks_per_thread = in_tasks/row_teams; +int ik_tasks_per_thread = ik_tasks/column_teams; +libxsmm_blasint my_in_start = my_row_id * in_tasks_per_thread; +libxsmm_blasint my_in_end = (my_row_id+1) * in_tasks_per_thread; +libxsmm_blasint my_ik_start = my_col_id * ik_tasks_per_thread; +libxsmm_blasint my_ik_end = (my_col_id+1) * ik_tasks_per_thread; +int perform_2d_decomp = (in_tasks % row_teams == 0 && ik_tasks % column_teams == 0 && row_teams*column_teams == handle->desc.threads && cBlocks <= 32 && kBlocks <= 32 && ik_tasks_per_thread <= 16 && in_tasks_per_thread <= 2 ) ? 1 : 0; + +if (perform_2d_decomp) { + /* Auxiliary arrays for batch-reduce gemms and potential prefetch */ + const element_input_type *A_array[16][2][32]; + const element_input_type *B_array[16][2][32]; + const element_input_type *A_array2[16][2][32]; + const element_input_type *B_array2[16][2][32]; + const element_input_type *A_array_pf[16][2][32]; + const element_input_type *B_array_pf[16][2][32]; + const element_input_type *A_array2_pf[16][2][32]; + const element_input_type *B_array2_pf[16][2][32]; + int ii, jj; + + /* lazy barrier init */ + libxsmm_barrier_init(handle->barrier, (int)ltid); + + /* All data is in column-major format */ + for (i = 0; i < t; ++i) { + /* Prepare arrays for the batch-reduce calls */ + for (ik = my_ik_start, ii = 0; ik < my_ik_end; ++ik, ii++ ) { + for (in = my_in_start, jj = 0; in < my_in_end; ++in, jj++ ) { + /* Prepare arrays for the call */ + for (ic = 0; ic < cBlocks; ic++) { + /* this is a small matmul */ + A_array[ii][jj][ic] = &LIBXSMM_VLA_ACCESS(4, w, ik, ic, 0, 0, cBlocks, bc, bk); + B_array[ii][jj][ic] = &LIBXSMM_VLA_ACCESS(5, x, i, in, ic, 0, 0, nBlocks, cBlocks, bn, bc); + } + /* z += U.h */ + if (0 == i) { + /* Prepare arrays for the call */ + for (ic = 0; ic < kBlocks; ic++) { + A_array2[ii][jj][ic] = &LIBXSMM_VLA_ACCESS(4, r, ik, ic, 0, 0, kBlocks, bk, bk); + B_array2[ii][jj][ic] = &LIBXSMM_VLA_ACCESS(4, hp, in, ic, 0, 0, kBlocks, bn, bk); + } + } else { + /* Prepare arrays for the call */ + for (ic = 0; ic < kBlocks; ic++) { + A_array2[ii][jj][ic] = &LIBXSMM_VLA_ACCESS(4, r, ik, ic, 0, 0, kBlocks, bk, bk); + B_array2[ii][jj][ic] = &LIBXSMM_VLA_ACCESS(5, h, i-1, in, ic, 0, 0, nBlocks, kBlocks, bn, bk); + } + } + } + } + + if (prefetch_mode != LIBXSMM_GEMM_PREFETCH_NONE) { /* coverity[dead_error_begin] */ + /* Prepare additional prefetch arrays that are shifted images of regular ones when external prefetching is requested */ + int pf_dist_A = 2; + int pf_dist_B = 4; + libxsmm_blasint total_blocks = in_tasks_per_thread*ik_tasks_per_thread*cBlocks; + const element_input_type **src_ptr = &A_array[0][0][0]; + const element_input_type **dst_ptr = &A_array_pf[0][0][0]; + for (ii = 0; ii < total_blocks - pf_dist_A; ii++) { + dst_ptr[ii] = src_ptr[ii+pf_dist_A]; + } + src_ptr = &B_array[0][0][0]; + dst_ptr = &B_array_pf[0][0][0]; + for (ii = 0; ii < total_blocks - pf_dist_B; ii++) { + dst_ptr[ii] = src_ptr[ii+pf_dist_B]; + } + total_blocks = in_tasks_per_thread*ik_tasks_per_thread*kBlocks; + src_ptr = &A_array2[0][0][0]; + dst_ptr = &A_array2_pf[0][0][0]; + for (ii = 0; ii < total_blocks - pf_dist_A; ii++) { + dst_ptr[ii] = src_ptr[ii+pf_dist_A]; + } + src_ptr = &B_array2[0][0][0]; + dst_ptr = &B_array2_pf[0][0][0]; + for (ii = 0; ii < total_blocks - pf_dist_B; ii++) { + dst_ptr[ii] = src_ptr[ii+pf_dist_B]; + } + } + + /* let's run the cell in blocks for good locality */ + for (ik = my_ik_start, ii = 0; ik < my_ik_end; ++ik, ii++ ) { + for (in = my_in_start, jj = 0; in < my_in_end; ++in, jj++ ) { + /* z = per_col(b) */ + libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &b[ik*bk]); + /* z += W.x */ + blocks = cBlocks; + batchreduce_kernela(&A_array[ii][jj][0], &B_array[ii][jj][0], &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &blocks, &A_array_pf[ii][jj][0], &B_array_pf[ii][jj][0]); + /* z += U.h */ + blocks = kBlocks; + batchreduce_kernelb(&A_array2[ii][jj][0], &B_array2[ii][jj][0], &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &blocks, &A_array2_pf[ii][jj][0], &B_array2_pf[ii][jj][0]); + +#if defined(LIBXSMM_DNN_RNN_RELU_FWD) + libxsmm_internal_matrix_relu_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(5, h, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk)); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_FWD) + libxsmm_internal_matrix_sigmoid_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(5, h, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk)); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_FWD) + libxsmm_internal_matrix_tanh_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(5, h, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk)); +#endif + } + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + } +} else { + /* Auxiliary arrays for batch-reduce gemms */ + const element_input_type *A_array[1024]; + const element_input_type *B_array[1024]; + const element_input_type *A_array2[1024]; + const element_input_type *B_array2[1024]; + assert(kBlocks <= 1024); + assert(cBlocks <= 1024); + + /* lazy barrier init */ + libxsmm_barrier_init(handle->barrier, (int)ltid); + + /* All data is in column-major format */ + for (i = 0; i < t; ++i) { + /* let's run the cell in blocks for good locality */ + for (inik = thr_begin; inik < thr_end; ++inik ) { + in = inik / (K/bk); + ik = inik % (K/bk); + + /* z = per_col(b) */ + libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &b[ik*bk]); + + /* z += W.x */ + /* Prepare arrays for the call */ + for (ic = 0; ic < cBlocks; ic++) { + /* this is a small matmul */ + A_array[ic] = &LIBXSMM_VLA_ACCESS(4, w, ik, ic, 0, 0, cBlocks, bc, bk); + B_array[ic] = &LIBXSMM_VLA_ACCESS(5, x, i, in, ic, 0, 0, nBlocks, cBlocks, bn, bc); + } + /* Reduce batch gemm call */ + blocks = cBlocks; + batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); + + /* z += U.h */ + if (0 == i) { + /* Prepare arrays for the call */ + for (ic = 0; ic < kBlocks; ic++) { + A_array2[ic] = &LIBXSMM_VLA_ACCESS(4, r, ik, ic, 0, 0, kBlocks, bk, bk); + B_array2[ic] = &LIBXSMM_VLA_ACCESS(4, hp, in, ic, 0, 0, kBlocks, bn, bk); + } + /* Reduce batch gemm call */ + blocks = kBlocks; + batchreduce_kernelb(A_array2, B_array2, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); + } else { + /* Prepare arrays for the call */ + for (ic = 0; ic < kBlocks; ic++) { + A_array2[ic] = &LIBXSMM_VLA_ACCESS(4, r, ik, ic, 0, 0, kBlocks, bk, bk); + B_array2[ic] = &LIBXSMM_VLA_ACCESS(5, h, i-1, in, ic, 0, 0, nBlocks, kBlocks, bn, bk); + } + /* Reduce batch gemm call */ + blocks = kBlocks; + batchreduce_kernelb(A_array2, B_array2, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &blocks); + } + +#if defined(LIBXSMM_DNN_RNN_RELU_FWD) + libxsmm_internal_matrix_relu_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(5, h, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk)); +#endif +#if defined(LIBXSMM_DNN_RNN_SIGMOID_FWD) + libxsmm_internal_matrix_sigmoid_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(5, h, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk)); +#endif +#if defined(LIBXSMM_DNN_RNN_TANH_FWD) + libxsmm_internal_matrix_tanh_ld( bk, bn, bk, &LIBXSMM_VLA_ACCESS(5, z, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk), &LIBXSMM_VLA_ACCESS(5, h, i, in, ik, 0, 0, nBlocks, kBlocks, bn, bk)); +#endif + } + libxsmm_barrier_wait(handle->barrier, (int)ltid); + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..0ea81cefca4b7cbb159139a6798bde1446d4f132 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c @@ -0,0 +1,148 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512) +#define LIBXSMM_DNN_CONVERT_F32_BF16(in, out, length) do { \ + unsigned int full_chunks = length / 16; \ + unsigned int remainder = length % 16; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 16) { \ + _mm256_storeu_si256((__m256i*)(out+__i), _mm512_cvtepi32_epi16( _mm512_srai_epi32( LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16( LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i) ),16)) ); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 16; \ + _mm256_storeu_si256((__m256i*)(out+__i), _mm512_cvtepi32_epi16( _mm512_srai_epi32( LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16( LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i) ),16)) ); \ + } \ + libxsmm_rne_convert_fp32_bf16((const float*)in+16*full_chunks, (libxsmm_bfloat16*)out+16*full_chunks, remainder); \ + } \ +} while(0) + +#define LIBXSMM_DNN_CONVERT_BF16_F32(in, out, length) do { \ + unsigned int full_chunks = length / 16; \ + unsigned int remainder = length % 16; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 16) { \ + _mm512_storeu_ps( out+__i, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(in+__i))),16))); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 16; \ + _mm512_storeu_ps( out+__i, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(in+__i))),16))); \ + } \ + libxsmm_convert_bf16_f32( (const libxsmm_bfloat16*)in+16*full_chunks, (float*)out+16*full_chunks, remainder); \ + } \ +} while(0) +#endif + +libxsmm_blasint bn = handle->bn; +libxsmm_blasint Bn = handle->Bn; +libxsmm_blasint bc = handle->bc; +libxsmm_blasint Bc = handle->Bc; + +/* loop counters */ +int i = 0; +libxsmm_blasint img1, img2, ifm1, ifm2; + +float rcp_N = 1.0f/handle->desc.N; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* number of tasks that could run in parallel for the batch */ +const int n_work = Bn * bn; +/* compute chunk size */ +const int n_chunksize = (n_work % handle->desc.threads == 0) ? (n_work / handle->desc.threads) : ((n_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int n_thr_begin = (ltid * n_chunksize < n_work) ? (ltid * n_chunksize) : n_work; +const int n_thr_end = ((ltid + 1) * n_chunksize < n_work) ? ((ltid + 1) * n_chunksize) : n_work; + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16) || defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512) +/* number of tasks that could run in parallel for the batch */ +const int nc_work = Bn * bn; +/* compute chunk size */ +const int nc_chunksize = (nc_work % handle->desc.threads == 0) ? (nc_work / handle->desc.threads) : ((nc_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int nc_thr_begin = (ltid * nc_chunksize < nc_work) ? (ltid * nc_chunksize) : nc_work; +const int nc_thr_end = ((ltid + 1) * nc_chunksize < nc_work) ? ((ltid + 1) * nc_chunksize) : nc_work; + +libxsmm_bfloat16* poutput_bf16 = (element_output_type*)handle->reg_output->data; +libxsmm_bfloat16* pdinput_bf16 = (element_input_type*)handle->grad_input->data; +float* poutput_fp32 = (float*)handle->scratch; +float* pdinput_fp32 = ((float*)handle->scratch)+(handle->desc.N*handle->desc.C); +LIBXSMM_VLA_DECL(4, const float, output, poutput_fp32, Bc, bn, bc); +LIBXSMM_VLA_DECL(4, float, dinput, pdinput_fp32, Bc, bn, bc); +#else +LIBXSMM_VLA_DECL(4, const element_output_type, output, (element_output_type*)handle->reg_output->data, Bc, bn, bc); +LIBXSMM_VLA_DECL(4, element_input_type, dinput, (element_input_type*)handle->grad_input->data, Bc, bn, bc); +#endif +LIBXSMM_VLA_DECL(2, const element_label_type, label, (element_label_type*)handle->label->data, bn); + +/* lazy barrier init */ +libxsmm_barrier_init( handle->barrier, ltid ); + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16) +for ( i = nc_thr_begin; i < nc_thr_end; ++i ) { + libxsmm_bfloat16_hp out; + out.i[0] = 0; + out.i[1] = poutput_bf16[i]; + poutput_fp32[i] = out.f; +} + +libxsmm_barrier_wait( handle->barrier, ltid ); +#endif +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512) +LIBXSMM_DNN_CONVERT_BF16_F32(poutput_bf16+nc_thr_begin, poutput_fp32+nc_thr_begin, nc_thr_end-nc_thr_begin); + +libxsmm_barrier_wait( handle->barrier, ltid ); +#endif + +for ( i = n_thr_begin; i < n_thr_end; ++i ) { + img1 = i/bn; + img2 = i%bn; + + /* set output to input and set compute max per image */ + for ( ifm1 = 0; ifm1 < Bc; ++ifm1 ) { + for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) { + if ( (ifm1*Bc)+ifm2 == (libxsmm_blasint)LIBXSMM_VLA_ACCESS( 2, label, img1, img2, bn ) ) { + LIBXSMM_VLA_ACCESS( 4, dinput, img1, ifm1, img2, ifm2, Bc, bn, bc ) = + ( LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) - 1.0f ) * rcp_N * handle->desc.loss_weight; + } else { + LIBXSMM_VLA_ACCESS( 4, dinput, img1, ifm1, img2, ifm2, Bc, bn, bc ) = + LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) * rcp_N * handle->desc.loss_weight; + } + } + } +} + +libxsmm_barrier_wait( handle->barrier, ltid ); + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16) +for ( i = nc_thr_begin; i < nc_thr_end; ++i ) { + libxsmm_bfloat16_hp din; + din.f = pdinput_fp32[i]; + pdinput_bf16[i] = din.i[1]; +} + +libxsmm_barrier_wait( handle->barrier, ltid ); +#endif +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512) +LIBXSMM_DNN_CONVERT_F32_BF16(pdinput_fp32+nc_thr_begin, pdinput_bf16+nc_thr_begin, nc_thr_end-nc_thr_begin); + +libxsmm_barrier_wait( handle->barrier, ltid ); +#undef LIBXSMM_DNN_CONVERT_F32_BF16 +#undef LIBXSMM_DNN_CONVERT_BF16_F32 +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..af7b022ad9778a3a294fcdde3cccb57583d027c2 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c @@ -0,0 +1,179 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512) +#define LIBXSMM_DNN_CONVERT_F32_BF16(in, out, length) do { \ + unsigned int full_chunks = length / 16; \ + unsigned int remainder = length % 16; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 16) { \ + _mm256_storeu_si256((__m256i*)(out+__i), _mm512_cvtepi32_epi16( _mm512_srai_epi32( LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16( LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i) ),16)) ); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 16; \ + _mm256_storeu_si256((__m256i*)(out+__i), _mm512_cvtepi32_epi16( _mm512_srai_epi32( LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16( LIBXSMM_INTRINSICS_MM512_LOAD_PS((const float*)in+__i) ),16)) ); \ + } \ + libxsmm_rne_convert_fp32_bf16((const float*)in+16*full_chunks, (libxsmm_bfloat16*)out+16*full_chunks, remainder); \ + } \ +} while(0) + +#define LIBXSMM_DNN_CONVERT_BF16_F32(in, out, length) do { \ + unsigned int full_chunks = length / 16; \ + unsigned int remainder = length % 16; \ + int __i = 0; \ + if (remainder == 0) { \ + for ( __i = 0; __i < length; __i+= 16) { \ + _mm512_storeu_ps( out+__i, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(in+__i))),16))); \ + } \ + } else { \ + unsigned int chunk; \ + for ( chunk = 0; chunk < full_chunks; chunk++) { \ + __i = chunk * 16; \ + _mm512_storeu_ps( out+__i, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(in+__i))),16))); \ + } \ + libxsmm_convert_bf16_f32( (const libxsmm_bfloat16*)in+16*full_chunks, (float*)out+16*full_chunks, remainder); \ + } \ +} while(0) +#endif + +libxsmm_blasint bn = handle->bn; +libxsmm_blasint Bn = handle->Bn; +libxsmm_blasint bc = handle->bc; +libxsmm_blasint Bc = handle->Bc; + +/* loop counters */ +int i = 0; +libxsmm_blasint img1, img2, ifm1, ifm2; + +/* computing first logical thread */ +const int ltid = tid - start_thread; + +/* number of tasks that could run in parallel for the batch */ +const int n_work = Bn * bn; +/* compute chunk size */ +const int n_chunksize = (n_work % handle->desc.threads == 0) ? (n_work / handle->desc.threads) : ((n_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int n_thr_begin = (ltid * n_chunksize < n_work) ? (ltid * n_chunksize) : n_work; +const int n_thr_end = ((ltid + 1) * n_chunksize < n_work) ? ((ltid + 1) * n_chunksize) : n_work; + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16) || defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512) +/* number of tasks that could run in parallel for the batch */ +const int nc_work = Bn * bn; +/* compute chunk size */ +const int nc_chunksize = (nc_work % handle->desc.threads == 0) ? (nc_work / handle->desc.threads) : ((nc_work / handle->desc.threads) + 1); +/* compute thr_begin and thr_end */ +const int nc_thr_begin = (ltid * nc_chunksize < nc_work) ? (ltid * nc_chunksize) : nc_work; +const int nc_thr_end = ((ltid + 1) * nc_chunksize < nc_work) ? ((ltid + 1) * nc_chunksize) : nc_work; + +libxsmm_bfloat16* poutput_bf16 = (element_output_type*)handle->reg_output->data; +libxsmm_bfloat16* pinput_bf16 = (element_input_type*)handle->reg_input->data; +float* poutput_fp32 = (float*)handle->scratch; +float* pinput_fp32 = ((float*)handle->scratch)+(handle->desc.N*handle->desc.C); +LIBXSMM_VLA_DECL(4, float, output, poutput_fp32, Bc, bn, bc); +LIBXSMM_VLA_DECL(4, const float, input, pinput_fp32, Bc, bn, bc); +#else +LIBXSMM_VLA_DECL(4, element_output_type, output, (element_output_type*)handle->reg_output->data, Bc, bn, bc); +LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type*)handle->reg_input->data, Bc, bn, bc); +#endif +LIBXSMM_VLA_DECL(2, const element_label_type, label, (element_label_type*)handle->label->data, bn); + +/* lazy barrier init */ +libxsmm_barrier_init( handle->barrier, ltid ); + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16) +for ( i = nc_thr_begin; i < nc_thr_end; ++i ) { + libxsmm_bfloat16_hp in; + in.i[0] = 0; + in.i[1] = pinput_bf16[i]; + pinput_fp32[i] = in.f; +} + +libxsmm_barrier_wait( handle->barrier, ltid ); +#endif +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512) +LIBXSMM_DNN_CONVERT_BF16_F32(pinput_bf16+nc_thr_begin, pinput_fp32+nc_thr_begin, nc_thr_end-nc_thr_begin); + +libxsmm_barrier_wait( handle->barrier, ltid ); +#endif + +for ( i = n_thr_begin; i < n_thr_end; ++i ) { + float max = FLT_MIN; + float sum_of_exp = 0.0f; + + img1 = i/bn; + img2 = i%bn; + + /* set output to input and set compute max per image */ + for ( ifm1 = 0; ifm1 < Bc; ++ifm1 ) { + for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) { + LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) = LIBXSMM_VLA_ACCESS( 4, input, img1, ifm1, img2, ifm2, Bc, bn, bc ); + if ( LIBXSMM_VLA_ACCESS( 4, input, img1, ifm1, img2, ifm2, Bc, bn, bc ) > max ) { + max = LIBXSMM_VLA_ACCESS( 4, input, img1, ifm1, img2, ifm2, Bc, bn, bc ); + } + } + } + + /* sum exp over outputs */ + for ( ifm1 = 0; ifm1 < Bc; ++ifm1 ) { + for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) { + LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) = (float)exp( (double)(LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) - max) ); + sum_of_exp += LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ); + } + } + + /* scale output */ + sum_of_exp = 1.0f/sum_of_exp; + for ( ifm1 = 0; ifm1 < Bc; ++ifm1 ) { + for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) { + LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) = LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1, img2, ifm2, Bc, bn, bc ) * sum_of_exp; + } + } +} + +libxsmm_barrier_wait( handle->barrier, ltid ); + +/* calculate loss single threaded */ +if ( ltid == 0 ) { + handle->loss = 0.0f; + for ( img1 = 0; img1 < Bn; ++img1 ) { + for ( img2 = 0; img2 FLT_MIN ) ? LIBXSMM_VLA_ACCESS( 4, output, img1, ifm1b, img2, ifm2b, Bc, bn, bc ) : FLT_MIN; + handle->loss = LIBXSMM_LOGF( val ); + } + } + handle->loss = ((-1.0f)*handle->loss)/handle->desc.N; +} + +libxsmm_barrier_wait( handle->barrier, ltid ); + +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16) +for ( i = nc_thr_begin; i < nc_thr_end; ++i ) { + libxsmm_bfloat16_hp in; + in.f = poutput_fp32[i]; + poutput_bf16[i] = in.i[1]; +} + +libxsmm_barrier_wait( handle->barrier, ltid ); +#endif +#if defined(LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512) +LIBXSMM_DNN_CONVERT_F32_BF16(poutput_fp32+nc_thr_begin, poutput_bf16+nc_thr_begin, nc_thr_end-nc_thr_begin); + +libxsmm_barrier_wait( handle->barrier, ltid ); +#undef LIBXSMM_DNN_CONVERT_F32_BF16 +#undef LIBXSMM_DNN_CONVERT_BF16_F32 +#endif + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..300a7bafa8b26b8d232eaa3ecec34d68bee3ff2b --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c @@ -0,0 +1,34 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +/* use for-loops to potentially leverage NUMA in the future */ +int i1, i2, i3; +#if defined(LIBXSMM_DNN_COPY_LOW_PRECISION) +int lpb = tensor->layout->dim_size[0]; +int bfm = tensor->layout->dim_size[1]; +int fmb = tensor->layout->dim_size[2]; +#else +int lpb = 1; +int bfm = tensor->layout->dim_size[0]; +int fmb = tensor->layout->dim_size[1]; +#endif + +const element_type* user_data = (const element_type*)data; +LIBXSMM_VLA_DECL(3, element_type, handle_data, (element_type*)tensor->data, bfm, lpb); + +for (i1 = 0; i1 < fmb; ++i1) { + for (i2 = 0; i2 < bfm; ++i2) { + for (i3 = 0; i3 < lpb; ++i3) { + LIBXSMM_VLA_ACCESS(3, handle_data, i1, i2, i3, bfm, lpb) = user_data[(i1*bfm*lpb) + (i2*lpb) + i3]; + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..f559d8868cff8a27cf6fb2a8ccba0406ed9b574f --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c @@ -0,0 +1,34 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +/* use for-loops to potentially leverage NUMA in the future */ +int i1, i2, i3; +#if defined(LIBXSMM_DNN_COPY_LOW_PRECISION) +int lpb = tensor->layout->dim_size[0]; +int bfm = tensor->layout->dim_size[1]; +int fmb = tensor->layout->dim_size[2]; +#else +int lpb = 1; +int bfm = tensor->layout->dim_size[0]; +int fmb = tensor->layout->dim_size[1]; +#endif + +element_type* user_data = (element_type*)data; +LIBXSMM_VLA_DECL(3, const element_type, handle_data, (const element_type*)tensor->data, bfm, lpb); + +for (i1 = 0; i1 < fmb; ++i1) { + for (i2 = 0; i2 < bfm; ++i2) { + for (i3 = 0; i3 < lpb; ++i3) { + user_data[(i1*bfm*lpb) + (i2*lpb) + i3] = LIBXSMM_VLA_ACCESS(3, handle_data, i1, i2, i3, bfm, lpb); + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..d1c1d1a5d533bdb022594370d9d14c9b530bd287 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c @@ -0,0 +1,51 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.) +******************************************************************************/ + +int i1, i2, i3, i4, i5, i6; +int lpb, bfm, W, H, fmb, N, C; +/* low precision formatting */ +if ( tensor->layout->num_dims == 6 ) { + lpb = tensor->layout->dim_size[0]; + bfm = tensor->layout->dim_size[1]; + W = tensor->layout->dim_size[2]; + H = tensor->layout->dim_size[3]; + fmb = tensor->layout->dim_size[4]; + N = tensor->layout->dim_size[5]; +} else { + lpb = 1; + bfm = tensor->layout->dim_size[0]; + W = tensor->layout->dim_size[1]; + H = tensor->layout->dim_size[2]; + fmb = tensor->layout->dim_size[3]; + N = tensor->layout->dim_size[4]; +} +C = fmb * bfm * lpb; + +/*printf(" layout act copy in N %i fmb %i H %i W %i bfm %i lpb %i \n", N, fmb, H, W, bfm, lpb);*/ +{ + LIBXSMM_VLA_DECL(6, element_type, handle_data_1, (element_type*)tensor->data, fmb, H, W, bfm, lpb); + LIBXSMM_VLA_DECL(4, const element_type, user_data, (const element_type*)data, C, H, W); + + for (i1 = 0; i1 < N; ++i1) { + for (i2 = 0; i2 < fmb; ++i2) { + for (i3 = 0; i3 < H; ++i3) { + for (i4 = 0; i4 < W; ++i4) { + for (i5 = 0; i5 < bfm; ++i5) { + for (i6 = 0; i6 < lpb; ++i6) { + LIBXSMM_VLA_ACCESS(6, handle_data_1, i1, i2, i3, i4, i5, i6, fmb, H, W, bfm, lpb) = + LIBXSMM_VLA_ACCESS(4, user_data, i1, ((size_t)i2*bfm*lpb) + ((size_t)i5*lpb) + i6, i3, i4, C, H, W); + } + } + } + } + } + } +} diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..356809e64a69d829d418638465837326e3748026 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c @@ -0,0 +1,51 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.) +******************************************************************************/ + +int i1, i2, i3, i4, i5, i6; +int lpb, bfm, W, H, fmb, N, C; +/* low precision formatting */ +if ( tensor->layout->num_dims == 6 ) { + lpb = tensor->layout->dim_size[0]; + bfm = tensor->layout->dim_size[1]; + W = tensor->layout->dim_size[2]; + H = tensor->layout->dim_size[3]; + fmb = tensor->layout->dim_size[4]; + N = tensor->layout->dim_size[5]; +} else { + lpb = 1; + bfm = tensor->layout->dim_size[0]; + W = tensor->layout->dim_size[1]; + H = tensor->layout->dim_size[2]; + fmb = tensor->layout->dim_size[3]; + N = tensor->layout->dim_size[4]; +} +C = fmb * bfm * lpb; + +/* printf(" layout act copy out N %i fmb %i H %i W %i bfm %i lpb %i \n", N, fmb, H, W, bfm, lpb); */ +{ + LIBXSMM_VLA_DECL(6, const element_type, handle_data_1, (const element_type*)tensor->data, fmb, H, W, bfm, lpb); + LIBXSMM_VLA_DECL(4, element_type, user_data, (element_type*)data, C, H, W); + + for (i1 = 0; i1 < N; ++i1) { + for (i2 = 0; i2 < fmb; ++i2) { + for (i3 = 0; i3 < H; ++i3) { + for (i4 = 0; i4 < W; ++i4) { + for (i5 = 0; i5 < bfm; ++i5) { + for (i6 = 0; i6 < lpb; ++i6) { + LIBXSMM_VLA_ACCESS(4, user_data, i1, ((size_t)i2*bfm*lpb) + ((size_t)i5*lpb) + i6, i3, i4, C, H, W) = + LIBXSMM_VLA_ACCESS(6, handle_data_1, i1, i2, i3, i4, i5, i6, fmb, H, W, bfm, lpb); + } + } + } + } + } + } +} diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..456f54e2fddb30ae999d14759494db3b504c34cd --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c @@ -0,0 +1,64 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.) +******************************************************************************/ + +/* @TODO: use for-loops to potentially leverage NUMA in the future */ +int i1, i2, i3, i4, i5, i6, i7; +int lpb = 0; +int bofm = 0; +int bifm = 0; +int S = 0; +int R = 0; +int ifmb = 0; +int ofmb = 0; +/* low precision formatting */ +if ( tensor->layout->num_dims == 7 ) { + lpb = tensor->layout->dim_size[0]; + bofm = tensor->layout->dim_size[1]; + bifm = tensor->layout->dim_size[2]; + S = tensor->layout->dim_size[3]; + R = tensor->layout->dim_size[4]; + ifmb = tensor->layout->dim_size[5]; + ofmb = tensor->layout->dim_size[6]; +} else if ( tensor->layout->num_dims == 6 ) { + lpb = 1; + bofm = tensor->layout->dim_size[0]; + bifm = tensor->layout->dim_size[1]; + S = tensor->layout->dim_size[2]; + R = tensor->layout->dim_size[3]; + ifmb = tensor->layout->dim_size[4]; + ofmb = tensor->layout->dim_size[5]; +} else { + /* should not happen, @TODO throw ERR */ +} + +/*printf("Layout of filters fil ofmb %i ifmb %i R %i S %i bifm %i bofm %i lpb %i \n", ofmb, ifmb, R, S, bifm, bofm, lpb);*/ +{ + LIBXSMM_VLA_DECL(7, element_type, handle_data_1, (element_type*)tensor->data, ifmb, R, S, bifm, bofm, lpb); + LIBXSMM_VLA_DECL(4, const element_type, user_data, (const element_type*)data, ifmb * bifm * lpb, R, S); + + for (i1 = 0; i1 < ofmb; ++i1) { + for (i2 = 0; i2 < ifmb; ++i2) { + for (i3 = 0; i3 < R; ++i3) { + for (i4 = 0; i4 < S; ++i4) { + for (i5 = 0; i5 < bifm; ++i5) { + for (i6 = 0; i6 < bofm; ++i6) { + for (i7 = 0; i7 < lpb; ++i7) { + LIBXSMM_VLA_ACCESS(7, handle_data_1, i1, i2, i3, i4, i5, i6, i7, ifmb, R, S, bifm, bofm, lpb) = + LIBXSMM_VLA_ACCESS(4, user_data, i1 * bofm + i6, ((size_t)i2*bifm*lpb) + ((size_t)i5*lpb) + i7, i3, i4, ifmb * bifm * lpb, R, S); + } + } + } + } + } + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..63175f9085cabbde3885702f2f34f08cfcf2f90c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c @@ -0,0 +1,63 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.) +******************************************************************************/ + +/* @TODO: use for-loops to potentially leverage NUMA in the future */ +int i1, i2, i3, i4, i5, i6, i7; +int lpb = 0; +int bofm = 0; +int bifm = 0; +int S = 0; +int R = 0; +int ifmb = 0; +int ofmb = 0; +/* low precision formatting */ +if ( tensor->layout->num_dims == 7 ) { + lpb = tensor->layout->dim_size[0]; + bofm = tensor->layout->dim_size[1]; + bifm = tensor->layout->dim_size[2]; + S = tensor->layout->dim_size[3]; + R = tensor->layout->dim_size[4]; + ifmb = tensor->layout->dim_size[5]; + ofmb = tensor->layout->dim_size[6]; +} else if ( tensor->layout->num_dims == 6 ) { + lpb = 1; + bofm = tensor->layout->dim_size[0]; + bifm = tensor->layout->dim_size[1]; + S = tensor->layout->dim_size[2]; + R = tensor->layout->dim_size[3]; + ifmb = tensor->layout->dim_size[4]; + ofmb = tensor->layout->dim_size[5]; +} else { + /* should not happen, @TODO throw ERR */ +} + +{ + LIBXSMM_VLA_DECL(4, element_type, user_data, (element_type*)data, ifmb * bifm * lpb, R, S); + LIBXSMM_VLA_DECL(7, const element_type, handle_data_1, (const element_type*)tensor->data, ifmb, R, S, bifm, bofm, lpb); + + for (i1 = 0; i1 < ofmb; ++i1) { + for (i2 = 0; i2 < ifmb; ++i2) { + for (i3 = 0; i3 < R; ++i3) { + for (i4 = 0; i4 < S; ++i4) { + for (i5 = 0; i5 < bifm; ++i5) { + for (i6 = 0; i6 < bofm; ++i6) { + for (i7 = 0; i7 < lpb; ++i7) { + LIBXSMM_VLA_ACCESS(4, user_data, i1 * bofm + i6, ((size_t)i2*bifm*lpb) + ((size_t)i5*lpb) + i7, i3, i4, ifmb * bifm * lpb, R, S) = + LIBXSMM_VLA_ACCESS(7, handle_data_1, i1, i2, i3, i4, i5, i6, i7, ifmb, R, S, bifm, bofm, lpb); + } + } + } + } + } + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_zero_rim_st_input_custom.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_zero_rim_st_input_custom.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..1cf97883fed37bff9e729e7486417a69b2d4d2e5 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_zero_rim_st_input_custom.tpl.c @@ -0,0 +1,25 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +/* this is crappy as it requires a complicated if... */ +if (handle->desc.pad_h_in > 0 || handle->desc.pad_w_in > 0) { + for ( ij = 0; ij < handle->ifhp; ij++ ) { + for ( ii = 0; ii < handle->ifwp; ii++ ) { + if ( (ij < handle->desc.pad_h_in) || (ij >= (handle->desc.H+handle->desc.pad_h_in)) || + (ii < handle->desc.pad_w_in) || (ii >= (handle->desc.W+handle->desc.pad_w_in)) ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ifm1lpblock, ij, ii, ifm2, handle->blocksifm*handle->fm_lp_block, handle->ifhp, handle->ifwp, handle->ifmblock) = (element_input_type)0; + } + } + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_dnn_zero_rim_st_input_nhwc.tpl.c b/third_party/libxsmm/src/template/libxsmm_dnn_zero_rim_st_input_nhwc.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..9809dfd3fa570b8a0754eee4577ebcea4670d644 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_dnn_zero_rim_st_input_nhwc.tpl.c @@ -0,0 +1,25 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +/* this is crappy as it requires a complicated if... */ +if (handle->desc.pad_h_in > 0 || handle->desc.pad_w_in > 0) { + for ( ij = 0; ij < handle->ifhp; ij++ ) { + for ( ii = 0; ii < handle->ifwp; ii++ ) { + if ( (ij < handle->desc.pad_h_in) || (ij >= (handle->desc.H+handle->desc.pad_h_in)) || + (ii < handle->desc.pad_w_in) || (ii >= (handle->desc.W+handle->desc.pad_w_in)) ) { + for (ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2) { + LIBXSMM_VLA_ACCESS(5, del_input, img, ij, ii, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock) = (element_input_type)0; + } + } + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_internal_gru_bwdupd_fused_eltwise_1.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_gru_bwdupd_fused_eltwise_1.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..1147cde51319eb51c3724780ab318e4cdec57fac --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_gru_bwdupd_fused_eltwise_1.tpl.c @@ -0,0 +1,72 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +{ + libxsmm_blasint _k, _j; + __m512 _vdh, _vdout, _vdf, _vdc, _vf, _vc, _vhp, _vt1, _vt2; + element_input_type* _dout = &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K); + element_input_type* _hp; + element_input_type* _c = &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K); + element_input_type* _f = &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K); + element_input_type* _dh = &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K); + element_input_type* _dc = &LIBXSMM_VLA_ACCESS(2, dc, in, ik, K); + element_input_type* _df = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + const __m512 _vneg_ones = _mm512_set1_ps( (float)-1.0 ); + const __m512 _vones = _mm512_set1_ps( (float)1.0 ); + if (0 == j) { + _hp = &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K); + } else { + _hp = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K); + } + if (j == t-1) { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_dh[(_j*K)+_k]); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_dout[(_j*K)+_k], _vdout); + _vc = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_c[(_j*K)+_k]); + _vt1 = _mm512_sub_ps(_vones, _vc); + _vt1 = _mm512_mul_ps(_vdout, _vt1); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_f[(_j*K)+_k]); + _vt2 = _mm512_fnmsub_ps(_vf, _vf, _vneg_ones); + _vdf = _mm512_mul_ps(_vt1, _vt2); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_df[(_j*K)+_k], _vdf); + _vhp = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_hp[(_j*K)+_k]); + _vt1 = _mm512_mul_ps(_vt1, _vc); + _vt2 = _mm512_sub_ps(_vhp, _vf); + _vdc = _mm512_mul_ps(_vt1, _vt2); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_dc[(_j*K)+_k], _vdc); + } + } + } else { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_dout[(_j*K)+_k]); + _vdh = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_dh[(_j*K)+_k]); + _vdout = _mm512_add_ps(_vdout, _vdh); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_dout[(_j*K)+_k], _vdout); + _vc = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_c[(_j*K)+_k]); + _vt1 = _mm512_sub_ps(_vones, _vc); + _vt1 = _mm512_mul_ps(_vdout, _vt1); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_f[(_j*K)+_k]); + _vt2 = _mm512_fnmsub_ps(_vf, _vf, _vneg_ones); + _vdf = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_df[(_j*K)+_k], _vdf); + _vhp = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_hp[(_j*K)+_k]); + _vt1 = _mm512_mul_ps(_vt1, _vc); + _vt2 = _mm512_sub_ps(_vhp, _vf); + _vdc = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_dc[(_j*K)+_k], _vdc); + } + } + } +} diff --git a/third_party/libxsmm/src/template/libxsmm_internal_gru_bwdupd_fused_eltwise_2.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_gru_bwdupd_fused_eltwise_2.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..aa0d3273b84dd7c293041774153d5e4126385ffc --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_gru_bwdupd_fused_eltwise_2.tpl.c @@ -0,0 +1,38 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Kunal Banerjee (Intel Corp.) +******************************************************************************/ + +{ + libxsmm_blasint _k, _j; + __m512 _vdi, _vdo, _vi, _vhp, _vt1, _vt2; + element_input_type* _hp; + element_input_type* _i = &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K); + element_input_type* _di = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + element_input_type* _do = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + const __m512 _vones = _mm512_set1_ps( (float)1.0 ); + if (0 == j) { + _hp = &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K); + } else { + _hp = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K); + } + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_i[(_j*K)+_k]); + _vt1 = _mm512_sub_ps(_vones, _vi); + _vt1 = _mm512_mul_ps(_vi, _vt1); + _vhp = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_hp[(_j*K)+_k]); + _vdo = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&_do[(_j*K)+_k]); + _vt2 = _mm512_mul_ps(_vdo, _vhp); + _vdi = _mm512_mul_ps(_vt1, _vt2); + LIBXSMM_INTRINSICS_MM512_STREAM_PS(&_di[(_j*K)+_k], _vdi); + } + } +} diff --git a/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..e3b4d9dfeb91476ef0f488f9121047d024f91f18 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise.tpl.c @@ -0,0 +1,113 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.), Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +{ + libxsmm_blasint _k, _j; + __m512 _vdout, _vdh, _vo, _vt1, _vt2, _vco, _vdcs, _vdcp, _vi, _vci, _vdci, _vdi, _vcps, _vf, _vdf, _vdp; + element_input_type* _dout = &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K); + element_input_type* _dh = &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K); + element_input_type* _o = &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K); + element_input_type* _co = &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K); + element_input_type* _dcs = &LIBXSMM_VLA_ACCESS(2, dcs, in, ik, K); + element_input_type* _i = &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K); + element_input_type* _ci = &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K); + element_input_type* _dci = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + element_input_type* _di = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + element_input_type* _cps = cps_ptr; + element_input_type* _f = &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K); + element_input_type* _df = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + element_input_type* _dp = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + element_input_type* _dcp = &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K); + const __m512 _vneg_ones = _mm512_set1_ps( (float)-1.0 ); + const __m512 _vones = _mm512_set1_ps( (float)1.0 ); + if (j == t-1) { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dh[(_j*K)+_k] ); + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_co[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _vneg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcs = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dcs[(_j*K)+_k] ); + _vdcp = _mm512_add_ps( _vdcs, _vt1 ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vi, _vdcp ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _vneg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dci[(_j*K)+_k], _vdci ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _vones, _vi ); + _vdi = _mm512_mul_ps( _vi, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_di[(_j*K)+_k], _vdi ); + _vcps = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*K)+_k] ); + _vt2 = _mm512_sub_ps( _vones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_df[(_j*K)+_k], _vdf ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _vones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dp[(_j*K)+_k], _vdp ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dcp[(_j*K)+_k], _vdcp ); + } + } + } else { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dout[(_j*K)+_k] ); + _vdh = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dh[(_j*K)+_k] ); + _vdout = _mm512_add_ps( _vdout, _vdh ); + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_co[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _vneg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dcp[(_j*K)+_k] ); + _vdcp = _mm512_add_ps( _vdcp, _vt1 ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vi, _vdcp ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _vneg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dci[(_j*K)+_k], _vdci ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _vones, _vi ); + _vdi = _mm512_mul_ps( _vi, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_di[(_j*K)+_k], _vdi ); + _vcps = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*K)+_k] ); + _vt2 = _mm512_sub_ps( _vones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_df[(_j*K)+_k], _vdf ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _vones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dp[(_j*K)+_k], _vdp ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dcp[(_j*K)+_k], _vdcp ); + } + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_ncnc_reformat_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_ncnc_reformat_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..fc1d8c6816788e4930f11190a6b0338c5467e22e --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_ncnc_reformat_bf16.tpl.c @@ -0,0 +1,159 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.), Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +#define NATIVE_STORECVT_F32_BF16(A,B) _mm256_storeu_si256((__m256i*)(A), (__m256i)LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(B)) +{ + float* _dout = &LIBXSMM_VLA_ACCESS(4, dout, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _dh = &LIBXSMM_VLA_ACCESS(5, dh, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + element_input_type* _o = &LIBXSMM_VLA_ACCESS(5, o, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + element_input_type* _co = &LIBXSMM_VLA_ACCESS(5, co, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + element_input_type* _dcs = &LIBXSMM_VLA_ACCESS(4, dcs, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _ii = &LIBXSMM_VLA_ACCESS(5, i, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + element_input_type* _ci = &LIBXSMM_VLA_ACCESS(5, ci, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + element_input_type* _dci = &LIBXSMM_VLA_ACCESS(4, dci, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _di = &LIBXSMM_VLA_ACCESS(4, di, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _cps = cps_ptr; + element_input_type* _f = &LIBXSMM_VLA_ACCESS(5, f, j, inb, ikb, 0, 0, nBlocks, kBlocks, bn, bk); + element_input_type* _df = &LIBXSMM_VLA_ACCESS(4, df, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _dp = &LIBXSMM_VLA_ACCESS(4, dp, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _dcp = &LIBXSMM_VLA_ACCESS(4, dcp, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _dciB = &LIBXSMM_VLA_ACCESS(5, dciB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + element_input_type* _diB = &LIBXSMM_VLA_ACCESS(5, diB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + element_input_type* _dfB = &LIBXSMM_VLA_ACCESS(5, dfB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + element_input_type* _dpB = &LIBXSMM_VLA_ACCESS(5, dpB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + + libxsmm_blasint _k, _j; + __m512 _vdout, _vdh, _vo, _vt1, _vt2, _vco, _vdcs, _vdcp, _vii, _vci, _vdci, _vdi, _vcps, _vf, _vdf, _vdp; + const __m512 _neg_ones = _mm512_set1_ps( (float)-1.0 ); + const __m512 _ones = _mm512_set1_ps( (float)1.0 ); + const int _lpb = 2; + + if (j == t-1) { + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = _mm512_loadcvt_bf16_fp32( &_dh[(_j*bk)+_k] ); + _vo = _mm512_loadcvt_bf16_fp32( &_o[(_j*bk)+_k] ); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = _mm512_loadcvt_bf16_fp32( &_co[(_j*bk)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _neg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcs = _mm512_loadcvt_bf16_fp32( &_dcs[(_j*bk)+_k] ); + _vdcp = _mm512_add_ps( _vdcs, _vt1 ); + _vii = _mm512_loadcvt_bf16_fp32( &_ii[(_j*bk)+_k] ); + _vt1 = _mm512_mul_ps( _vii, _vdcp ); + _vci = _mm512_loadcvt_bf16_fp32( &_ci[(_j*bk)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _neg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + NATIVE_STORECVT_F32_BF16( &_dci[(_j*bk)+_k], _vdci ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _ones, _vii ); + _vdi = _mm512_mul_ps( _vii, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + NATIVE_STORECVT_F32_BF16( &_di[(_j*bk)+_k], _vdi ); + _vcps = _mm512_loadcvt_bf16_fp32( &_cps[(_j*bk)+_k] ); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = _mm512_loadcvt_bf16_fp32( &_f[(_j*bk)+_k] ); + _vt2 = _mm512_sub_ps( _ones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + NATIVE_STORECVT_F32_BF16( &_df[(_j*bk)+_k], _vdf ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _ones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + NATIVE_STORECVT_F32_BF16( &_dp[(_j*bk)+_k], _vdp ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + NATIVE_STORECVT_F32_BF16( &_dcp[(_j*bk)+_k], _vdcp ); + } + } + } else { + for ( _j = 0; _j < bn; ++_j ) { + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dout[(_j*bk)+_k] ); + _vdh = _mm512_loadcvt_bf16_fp32( &_dh[(_j*bk)+_k] ); + _vdout = _mm512_add_ps( _vdout, _vdh ); + _vo = _mm512_loadcvt_bf16_fp32( &_o[(_j*bk)+_k] ); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = _mm512_loadcvt_bf16_fp32( &_co[(_j*bk)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _neg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcp = _mm512_loadcvt_bf16_fp32( &_dcp[(_j*bk)+_k] ); + _vdcp = _mm512_add_ps( _vdcp, _vt1 ); + _vii = _mm512_loadcvt_bf16_fp32( &_ii[(_j*bk)+_k] ); + _vt1 = _mm512_mul_ps( _vii, _vdcp ); + _vci = _mm512_loadcvt_bf16_fp32( &_ci[(_j*bk)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _neg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + NATIVE_STORECVT_F32_BF16( &_dci[(_j*bk)+_k], _vdci ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _ones, _vii ); + _vdi = _mm512_mul_ps( _vii, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + NATIVE_STORECVT_F32_BF16( &_di[(_j*bk)+_k], _vdi ); + _vcps = _mm512_loadcvt_bf16_fp32( &_cps[(_j*bk)+_k] ); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = _mm512_loadcvt_bf16_fp32( &_f[(_j*bk)+_k] ); + _vt2 = _mm512_sub_ps( _ones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + NATIVE_STORECVT_F32_BF16( &_df[(_j*bk)+_k], _vdf ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _ones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + NATIVE_STORECVT_F32_BF16( &_dp[(_j*bk)+_k], _vdp ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + NATIVE_STORECVT_F32_BF16( &_dcp[(_j*bk)+_k], _vdcp ); + } + } + } + { + /* Store di/dci/df/dp to diB/dciB/dfB/dpB which is CNNC AND vnni format */ + const __m512i perm_idx = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m256i c0, c1; + __m512i _c01; + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, di_, _di, bk); + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, df_, _df, bk); + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, dp_, _dp, bk); + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, dci_, _dci, bk); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, diB_, _diB, bk, _lpb); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, dfB_, _dfB, bk, _lpb); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, dpB_, _dpB, bk, _lpb); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, dciB_, _dciB, bk, _lpb); + for (_j = 0; _j < bn; _j+=2) { + for (_k = 0; _k < bk; _k+=16) { + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, di_, _j, _k, bk)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, di_, _j+1, _k, bk)); + _c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + _c01 = _mm512_inserti64x4 (_c01, c1, 1); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(3, diB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, _c01)); + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, df_, _j, _k, bk)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, df_, _j+1, _k, bk)); + _c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + _c01 = _mm512_inserti64x4 (_c01, c1, 1); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(3, dfB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, _c01)); + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dp_, _j, _k, bk)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dp_, _j+1, _k, bk)); + _c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + _c01 = _mm512_inserti64x4 (_c01, c1, 1); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(3, dpB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, _c01)); + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dci_, _j, _k, bk)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dci_, _j+1, _k, bk)); + _c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + _c01 = _mm512_inserti64x4 (_c01, c1, 1); + _mm512_store_epi32(&LIBXSMM_VLA_ACCESS(3, dciB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, _c01)); + } + } + } +} + +#undef NATIVE_STORECVT_F32_BF16 + diff --git a/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..623cf71d8ca6f4f9bf087662c8f23540ff0c226c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat.tpl.c @@ -0,0 +1,124 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.), Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +{ + libxsmm_blasint _k, _j; + __m512 _vdout, _vdh, _vo, _vt1, _vt2, _vco, _vdcs, _vdcp, _vi, _vci, _vdci, _vdi, _vcps, _vf, _vdf, _vdp; + element_input_type* _dout = &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K); + element_input_type* _dh = &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K); + element_input_type* _o = &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K); + element_input_type* _co = &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K); + element_input_type* _dcs = &LIBXSMM_VLA_ACCESS(2, dcs, in, ik, K); + element_input_type* _i = &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K); + element_input_type* _ci = &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K); + element_input_type* _dci = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + element_input_type* _di = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + element_input_type* _cps = cps_ptr; + element_input_type* _f = &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K); + element_input_type* _df = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + element_input_type* _dp = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + element_input_type* _dcp = &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K); + element_input_type* _dciB = &LIBXSMM_VLA_ACCESS(4, dciB, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _diB = &LIBXSMM_VLA_ACCESS(4, diB, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _dfB = &LIBXSMM_VLA_ACCESS(4, dfB, inb, ikb, 0, 0, kBlocks, bn, bk); + element_input_type* _dpB = &LIBXSMM_VLA_ACCESS(4, dpB, inb, ikb, 0, 0, kBlocks, bn, bk); + const __m512 _vneg_ones = _mm512_set1_ps( (float)-1.0 ); + const __m512 _vones = _mm512_set1_ps( (float)1.0 ); + if (j == t-1) { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dh[(_j*K)+_k] ); + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_co[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _vneg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcs = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dcs[(_j*K)+_k] ); + _vdcp = _mm512_add_ps( _vdcs, _vt1 ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vi, _vdcp ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _vneg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dci[(_j*K)+_k], _vdci ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dciB[(_j*bk)+_k], _vdci ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _vones, _vi ); + _vdi = _mm512_mul_ps( _vi, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_di[(_j*K)+_k], _vdi ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_diB[(_j*bk)+_k], _vdi ); + _vcps = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*K)+_k] ); + _vt2 = _mm512_sub_ps( _vones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_df[(_j*K)+_k], _vdf ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dfB[(_j*bk)+_k], _vdf ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _vones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dp[(_j*K)+_k], _vdp ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dpB[(_j*bk)+_k], _vdp ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dcp[(_j*K)+_k], _vdcp ); + } + } + } else { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dout[(_j*K)+_k] ); + _vdh = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dh[(_j*K)+_k] ); + _vdout = _mm512_add_ps( _vdout, _vdh ); + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_co[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _vneg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dcp[(_j*K)+_k] ); + _vdcp = _mm512_add_ps( _vdcp, _vt1 ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vi, _vdcp ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*K)+_k] ); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _vneg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dci[(_j*K)+_k], _vdci ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dciB[(_j*bk)+_k], _vdci ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _vones, _vi ); + _vdi = _mm512_mul_ps( _vi, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_di[(_j*K)+_k], _vdi ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_diB[(_j*bk)+_k], _vdi ); + _vcps = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*K)+_k] ); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*K)+_k] ); + _vt2 = _mm512_sub_ps( _vones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_df[(_j*K)+_k], _vdf ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dfB[(_j*bk)+_k], _vdf ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _vones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dp[(_j*K)+_k], _vdp ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dpB[(_j*bk)+_k], _vdp ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_dcp[(_j*K)+_k], _vdcp ); + } + } + } +} diff --git a/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..4ebd4aae2b54f5fdb7af14d8414b3247469cb2d4 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_lstm_bwdupd_fused_eltwise_reformat_bf16.tpl.c @@ -0,0 +1,169 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.), Alexander Heinecke (Intel Corp.) +******************************************************************************/ +{ + float* _dout = &LIBXSMM_VLA_ACCESS(2, dout, in, ik, K); + element_input_type* _dh = &LIBXSMM_VLA_ACCESS(3, dh, j, in, ik, N, K); + element_input_type* _o = &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K); + element_input_type* _co = &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K); + element_input_type* _dcs = &LIBXSMM_VLA_ACCESS(2, dcs, in, ik, K); + element_input_type* _ii = &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K); + element_input_type* _ci = &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K); + element_input_type* _dci = &LIBXSMM_VLA_ACCESS(2, dci, in, ik, K); + element_input_type* _di = &LIBXSMM_VLA_ACCESS(2, di, in, ik, K); + element_input_type* _cps = cps_ptr; + element_input_type* _f = &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K); + element_input_type* _df = &LIBXSMM_VLA_ACCESS(2, df, in, ik, K); + element_input_type* _dp = &LIBXSMM_VLA_ACCESS(2, dp, in, ik, K); + element_input_type* _dcp = &LIBXSMM_VLA_ACCESS(2, dcp, in, ik, K); + element_input_type* _dciB = &LIBXSMM_VLA_ACCESS(5, dciB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + element_input_type* _diB = &LIBXSMM_VLA_ACCESS(5, diB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + element_input_type* _dfB = &LIBXSMM_VLA_ACCESS(5, dfB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + element_input_type* _dpB = &LIBXSMM_VLA_ACCESS(5, dpB, ikb, inb, 0, 0, 0, nBlocks, bn_lp, bk, lpb); + + libxsmm_blasint _k, _j; + __m512 _vdout, _vdh, _vo, _vt1, _vt2, _vco, _vdcs, _vdcp, _vii, _vci, _vdci, _vdi, _vcps, _vf, _vdf, _vdp; + const __m512 _neg_ones = _mm512_set1_ps( (float)-1.0 ); + const __m512 _ones = _mm512_set1_ps( (float)1.0 ); + int _lpb = 2; + + if (j == t-1) { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_dh[(_j*K)+_k] )); + _vo = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_o[(_j*K)+_k] )); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_co[(_j*K)+_k] )); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _neg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcs = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_dcs[(_j*K)+_k] )); + _vdcp = _mm512_add_ps( _vdcs, _vt1 ); + _vii = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_ii[(_j*K)+_k] )); + _vt1 = _mm512_mul_ps( _vii, _vdcp ); + _vci = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_ci[(_j*K)+_k] )); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _neg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + _mm256_stream_si256((__m256i*)&_dci[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdci) ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _ones, _vii ); + _vdi = _mm512_mul_ps( _vii, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + _mm256_stream_si256((__m256i*)&_di[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdi) ); + _vcps = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_cps[(_j*K)+_k] )); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_f[(_j*K)+_k] )); + _vt2 = _mm512_sub_ps( _ones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + _mm256_stream_si256((__m256i*)&_df[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdf) ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _ones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + _mm256_stream_si256((__m256i*)&_dp[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdp) ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + _mm256_stream_si256((__m256i*)&_dcp[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdcp) ); + } + } + } else { + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vdout = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_dout[(_j*K)+_k] ); + _vdh = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_dh[(_j*K)+_k] )); + _vdout = _mm512_add_ps( _vdout, _vdh ); + _vo = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_o[(_j*K)+_k] )); + _vt1 = _mm512_mul_ps( _vdout, _vo ); + _vco = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_co[(_j*K)+_k] )); + _vt2 = _mm512_fnmsub_ps ( _vco, _vco, _neg_ones); + _vt1 = _mm512_mul_ps( _vt1, _vt2 ); + _vdcp = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_dcp[(_j*K)+_k] )); + _vdcp = _mm512_add_ps( _vdcp, _vt1 ); + _vii = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_ii[(_j*K)+_k] )); + _vt1 = _mm512_mul_ps( _vii, _vdcp ); + _vci = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_ci[(_j*K)+_k] )); + _vt2 = _mm512_fnmsub_ps ( _vci, _vci, _neg_ones); + _vdci = _mm512_mul_ps( _vt1, _vt2 ); + _mm256_stream_si256((__m256i*)&_dci[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdci) ); + _vt1 = _mm512_mul_ps( _vci, _vdcp ); + _vt2 = _mm512_sub_ps( _ones, _vii ); + _vdi = _mm512_mul_ps( _vii, _vt2); + _vdi = _mm512_mul_ps( _vdi, _vt1); + _mm256_stream_si256((__m256i*)&_di[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdi) ); + _vcps = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_cps[(_j*K)+_k] )); + _vt1 = _mm512_mul_ps( _vcps, _vdcp ); + _vf = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&_f[(_j*K)+_k] )); + _vt2 = _mm512_sub_ps( _ones, _vf ); + _vdf = _mm512_mul_ps( _vf, _vt2); + _vdf = _mm512_mul_ps( _vdf, _vt1); + _mm256_stream_si256((__m256i*)&_df[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdf) ); + _vt1 = _mm512_mul_ps( _vdout, _vco); + _vt2 = _mm512_sub_ps( _ones, _vo ); + _vt2 = _mm512_mul_ps( _vo, _vt2); + _vdp = _mm512_mul_ps( _vt1, _vt2 ); + _mm256_stream_si256((__m256i*)&_dp[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdp) ); + _vdcp = _mm512_mul_ps( _vdcp, _vf); + _mm256_stream_si256((__m256i*)&_dcp[(_j*K)+_k], LIBXSMM_INTRINSISCS_MM512_CVTNEPS_PBH(_vdcp) ); + } + } + } + + { /* Store di/dci/df/dp to diB/dciB/dfB/dpB which is CNNC AND vnni format */ + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, di_, _di, K); + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, df_, _df, K); + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, dp_, _dp, K); + LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, dci_, _dci, K); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, diB_, _diB, bk, _lpb); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, dfB_, _dfB, bk, _lpb); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, dpB_, _dpB, bk, _lpb); + LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, dciB_, _dciB, bk, _lpb); + if ( (bn % 2 == 0) && (bk % 16 == 0) ) { + const __m512i perm_idx = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m256i c0, c1; + __m512i c01; + for (_j = 0; _j < bn; _j+=2) { + for (_k = 0; _k < bk; _k+=16) { + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, di_, _j, _k, K)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, di_, _j+1, _k, K)); + c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + c01 = _mm512_inserti64x4 (c01, c1, 1); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(3, diB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, c01)); + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, df_, _j, _k, K)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, df_, _j+1, _k, K)); + c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + c01 = _mm512_inserti64x4 (c01, c1, 1); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(3, dfB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, c01)); + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dp_, _j, _k, K)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dp_, _j+1, _k, K)); + c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + c01 = _mm512_inserti64x4 (c01, c1, 1); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(3, dpB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, c01)); + c0 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dci_, _j, _k, K)); + c1 = _mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, dci_, _j+1, _k, K)); + c01 = _mm512_inserti64x4 (LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), c0, 0); + c01 = _mm512_inserti64x4 (c01, c1, 1); + _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(3, dciB_, _j/_lpb, _k, 0, bk, _lpb), _mm512_permutexvar_epi16(perm_idx, c01)); + } + } + } else { + for (_j = 0; _j < bn; _j++) { + for (_k = 0; _k < bk; _k++) { + LIBXSMM_VLA_ACCESS(3, diB_, _j / _lpb, _k, _j%_lpb, bk, _lpb) = LIBXSMM_VLA_ACCESS(2, di_, _j, _k, K); + LIBXSMM_VLA_ACCESS(3, dfB_, _j / _lpb, _k, _j%_lpb, bk, _lpb) = LIBXSMM_VLA_ACCESS(2, df_, _j, _k, K); + LIBXSMM_VLA_ACCESS(3, dpB_, _j / _lpb, _k, _j%_lpb, bk, _lpb) = LIBXSMM_VLA_ACCESS(2, dp_, _j, _k, K); + LIBXSMM_VLA_ACCESS(3, dciB_, _j / _lpb, _k, _j%_lpb, bk, _lpb) = LIBXSMM_VLA_ACCESS(2, dci_, _j, _k, K); + } + } + } + } +} + + diff --git a/third_party/libxsmm/src/template/libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..7a50dd1ddd8f95e33a731ef52f3b636d99bba790 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c @@ -0,0 +1,50 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.), Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +{ + libxsmm_blasint _k, _j; + element_input_type* _o = &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K); + element_input_type* _i = &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K); + element_input_type* _f = &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K); + element_input_type* _ci = &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K); + element_input_type* _cps = cps_ptr; + element_input_type* _cs = &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K); + element_input_type* _h = &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K); + element_input_type* _co = &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K); + __m512 _vf, _vcs, _vi, _vci, _vco, _vo, _vh; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*K)+_k] ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*K)+_k] ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*K)+_k] ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*K)+_k] ); + _vcs = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*K)+_k] ); + _vo = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vo, _halves ) ), _halves, _halves); + _vi = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vi, _halves ) ), _halves, _halves); + _vci = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vci ); + _vf = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vf, _halves ) ), _halves, _halves); + _vcs = _mm512_mul_ps( _vf, _vcs ); + _vcs = _mm512_fmadd_ps( _vi, _vci, _vcs ); + _vco = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vcs ); + _vh = _mm512_mul_ps( _vo, _vco ); + _mm512_storeu_ps( &_o[(_j*K)+_k], _vo ); + _mm512_storeu_ps( &_i[(_j*K)+_k], _vi ); + _mm512_storeu_ps( &_ci[(_j*K)+_k], _vci ); + _mm512_storeu_ps( &_f[(_j*K)+_k], _vf ); + _mm512_storeu_ps( &_cs[(_j*K)+_k], _vcs ); + _mm512_storeu_ps( &_co[(_j*K)+_k], _vco ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_h[(_j*K)+_k], _vh ); + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c b/third_party/libxsmm/src/template/libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..4d1c860343f33514524850b1bbd2ec10d4dfc9a2 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_internal_lstm_fwd_fused_eltwise_bf16.tpl.c @@ -0,0 +1,50 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Evangelos Georganas (Intel Corp.), Alexander Heinecke (Intel Corp.) +******************************************************************************/ + +{ + libxsmm_blasint _k, _j; + float* _o = &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K); + float* _i = &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K); + float* _f = &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K); + float* _ci = &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K); + float* _cps = cps_ptr; + float* _cs = &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K); + float* _h = &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K); + float* _co = &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K); + __m512 _vf, _vcs, _vi, _vci, _vco, _vo, _vh; + const __m512 _halves = _mm512_set1_ps( (LIBXSMM_DNN_ELTWISE_FTYPE)0.5 ); + for ( _j = 0; _j < bn; ++_j ) { + LIBXSMM_PRAGMA_UNROLL_N(4) + for ( _k = 0; _k < bk; _k += 16 ) { + _vo = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_o[(_j*K)+_k] ); + _vi = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_i[(_j*K)+_k] ); + _vci = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_ci[(_j*K)+_k] ); + _vf = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_f[(_j*K)+_k] ); + _vcs = LIBXSMM_INTRINSICS_MM512_LOAD_PS( &_cps[(_j*K)+_k] ); + _vo = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vo, _halves ) ), _halves, _halves); + _vi = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vi, _halves ) ), _halves, _halves); + _vci = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vci ); + _vf = _mm512_fmadd_ps( LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _mm512_mul_ps( _vf, _halves ) ), _halves, _halves); + _vcs = _mm512_mul_ps( _vf, _vcs ); + _vcs = _mm512_fmadd_ps( _vi, _vci, _vcs ); + _vco = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( _vcs ); + _vh = _mm512_mul_ps( _vo, _vco ); + _mm512_storeu_ps( &_o[(_j*K)+_k], _vo ); + _mm512_storeu_ps( &_i[(_j*K)+_k], _vi ); + _mm512_storeu_ps( &_ci[(_j*K)+_k], _vci ); + _mm512_storeu_ps( &_f[(_j*K)+_k], _vf ); + _mm512_storeu_ps( &_cs[(_j*K)+_k], _vcs ); + _mm512_storeu_ps( &_co[(_j*K)+_k], _vco ); + LIBXSMM_INTRINSICS_MM512_STREAM_PS( &_h[(_j*K)+_k], _vh ); + } + } +} + diff --git a/third_party/libxsmm/src/template/libxsmm_matdiff.tpl.c b/third_party/libxsmm/src/template/libxsmm_matdiff.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..c9a83c7659bdaea6aadff0fe8d2cc70229d171e7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_matdiff.tpl.c @@ -0,0 +1,174 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Hans Pabst (Intel Corp.) +******************************************************************************/ + +const LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE *const real_ref = (const LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE*)ref; +const LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE *const real_tst = (const LIBXSMM_MATDIFF_TEMPLATE_ELEM_TYPE*)tst; +double compf = 0, compfr = 0, compft = 0, normfr = 0, normft = 0, normr = 0, normt = 0; +double normrc = 0, normtc = 0, compr = 0, compt = 0, compd = 0; +libxsmm_blasint i, j; + +for (i = 0; i < nn; ++i) { + double comprj = 0, comptj = 0, compij = 0; + double normrj = 0, normtj = 0, normij = 0; + double v0, v1; + + for (j = 0; j < mm; ++j) { + const double ti = (0 != real_tst ? real_tst[i*ldt+j] : 0); + const double ri = real_ref[i*ldr+j]; + const double ta = LIBXSMM_ABS(ti); + const double ra = LIBXSMM_ABS(ri); + + /* minimum/maximum of reference set */ + if (ri < info->min_ref) info->min_ref = ri; + if (ri > info->max_ref) info->max_ref = ri; + + if (LIBXSMM_NOTNAN(ti) && inf > ta) { + const double di = (0 != real_tst ? (ri < ti ? (ti - ri) : (ri - ti)) : 0); + + /* minimum/maximum of test set */ + if (ti < info->min_tst) info->min_tst = ti; + if (ti > info->max_tst) info->max_tst = ti; + + /* maximum absolute error and location */ + if (info->linf_abs < di) { + info->linf_abs = di; + info->v_ref = ri; + info->v_tst = ti; + info->m = j; + info->n = i; + } + + /* maximum error relative to current value */ + if (0 < ra) { + const double dri = di / ra; + if (info->linf_rel < dri) info->linf_rel = dri; + /* sum of relative differences */ + v0 = dri * dri; + if (inf > v0) { + v0 -= compd; + v1 = info->l2_rel + v0; + compd = (v1 - info->l2_rel) - v0; + info->l2_rel = v1; + } + } + + /* row-wise sum of reference values with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(ra, &normrj, &comprj); + + /* row-wise sum of test values with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(ta, &normtj, &comptj); + + /* row-wise sum of differences with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(di, &normij, &compij); + + /* Froebenius-norm of reference matrix with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(ri * ri, &normfr, &compfr); + + /* Froebenius-norm of test matrix with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(ti * ti, &normft, &compft); + + /* Froebenius-norm of differences with Kahan compensation */ + v0 = di * di; + if (inf > v0) { + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(v0, &info->l2_abs, &compf); + } + } + else { /* NaN */ + info->m = j; info->n = i; + result_nan = ((LIBXSMM_NOTNAN(ri) && inf > ra) ? 1 : 2); + break; + } + } + + if (0 == result_nan) { + /* summarize reference values */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(normrj, &info->l1_ref, &compr); + + /* summarize test values */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(normtj, &info->l1_tst, &compt); + + /* calculate Infinity-norm of differences */ + if (info->normi_abs < normij) info->normi_abs = normij; + /* calculate Infinity-norm of reference/test values */ + if (normr < normrj) normr = normrj; + if (normt < normtj) normt = normtj; + } + else { + break; + } +} + +if (0 == result_nan) { + double compr_var = 0, compt_var = 0; + + /* initial variance */ + assert(0 == info->var_ref); /* !LIBXSMM_ASSERT */ + assert(0 == info->var_tst); /* !LIBXSMM_ASSERT */ + + if (0 != ntotal) { /* final average */ + info->avg_ref = info->l1_ref / ntotal; + info->avg_tst = info->l1_tst / ntotal; + } + + /* Infinity-norm relative to reference */ + info->normi_rel = LIBXSMM_MATDIFF_DIV(info->normi_abs, normr, normt); + /* Froebenius-norm relative to reference */ + info->normf_rel = LIBXSMM_MATDIFF_DIV(info->l2_abs, normfr, normft); + + for (j = 0; j < mm; ++j) { + double compri = 0, compti = 0, comp1 = 0; + double normri = 0, normti = 0, norm1 = 0; + + for (i = 0; i < nn; ++i) { + const double ri = real_ref[i*ldr + j], ti = (0 != real_tst ? real_tst[i*ldt + j] : 0); + const double di = (0 != real_tst ? (ri < ti ? (ti - ri) : (ri - ti)) : 0); + const double rd = ri - info->avg_ref, td = ti - info->avg_tst; + const double ra = LIBXSMM_ABS(ri), ta = LIBXSMM_ABS(ti); + + /* variance of reference set with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(rd * rd, &info->var_ref, &compr_var); + + /* variance of test set with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(td * td, &info->var_tst, &compt_var); + + /* column-wise sum of reference values with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(ra, &normri, &compri); + + /* column-wise sum of test values with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(ta, &normti, &compti); + + /* column-wise sum of differences with Kahan compensation */ + LIBXSMM_PRAGMA_FORCEINLINE + libxsmm_kahan_sum(di, &norm1, &comp1); + } + + /* calculate One-norm of differences */ + if (info->norm1_abs < norm1) info->norm1_abs = norm1; + /* calculate One-norm of reference/test values */ + if (normrc < normri) normrc = normri; + if (normtc < normti) normtc = normti; + } + + /* One-norm relative to reference */ + info->norm1_rel = LIBXSMM_MATDIFF_DIV(info->norm1_abs, normrc, normtc); +} diff --git a/third_party/libxsmm/src/template/libxsmm_spmdm_compute_bfloat16_thread.tpl.c b/third_party/libxsmm/src/template/libxsmm_spmdm_compute_bfloat16_thread.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..5f16f1ab5f02a7fb90de5b9ffeb35fff4480e017 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_spmdm_compute_bfloat16_thread.tpl.c @@ -0,0 +1,564 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish (Intel Corp.) +******************************************************************************/ + +const int m_blocks = handle->mb; +/*const int n_blocks = handle->nb;*/ +const int k_blocks = handle->kb; +const int m_block_size = handle->bm; +const int n_block_size = handle->bn; +const int k_block_size = handle->bk; +int mb = block_id / handle->nb; +int nb = block_id % handle->nb; + + +#define LIBXSMM_SPMDM_COMPUTE_NREGS (6) +int m_overall_start = mb*m_block_size; +int m_overall_end = (mb + 1)*m_block_size; +int num_m; +int num_m_aligned; + +int n_overall_start = nb*n_block_size; +int n_overall_end = (nb + 1)*n_block_size; +int num_n; +int m, n, k, kb; +int last_block_n, num_full_regs, last_n_start; + +int k_overall_start, k_overall_end, num_k; + +float *const scratch_C = (float *)(handle->base_ptr_scratch_B_scratch_C + (size_t)tid*handle->memory_for_scratch_per_thread); +float *const scratch_B = (float *)(handle->base_ptr_scratch_B_scratch_C + (size_t)tid*handle->memory_for_scratch_per_thread + (size_t)m_block_size*n_block_size*sizeof(float)); +#if 0 +float *const scratch_C = (float *)(handle->spmdm_scratch_C + tid*m_block_size*n_block_size*sizeof(float)); +float *const scratch_B = (float *)(handle->spmdm_scratch_B + tid*k_block_size*n_block_size*sizeof(float)); +#endif + +SIMDTYPE_FP32 sum[2*LIBXSMM_SPMDM_COMPUTE_NREGS]; +float* LIBXSMM_RESTRICT ptr_result; +#if SIMD_WIDTH_FP32 > 1 +SIMDTYPE_INT32 vzero = _MM_SETZERO_INT32(); +#endif + +LIBXSMM_UNUSED(nthreads); +LIBXSMM_UNUSED(transa); +LIBXSMM_UNUSED(alpha); +LIBXSMM_UNUSED(beta); +LIBXSMM_UNUSED(tid); + +/* really is twice this */ +assert(n_block_size == LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32); + +if (m_overall_end > handle->m) m_overall_end = handle->m; +num_m = (m_overall_end - m_overall_start); +num_m_aligned = (num_m / 2) * 2; + +if (n_overall_end > handle->n) n_overall_end = handle->n; +num_n = (n_overall_end - n_overall_start); +last_block_n = (num_n != n_block_size); +num_full_regs = (num_n / SIMD_WIDTH_FP32); +if ((num_full_regs > 0) && (num_full_regs%2)) num_full_regs--; +last_n_start = num_full_regs*SIMD_WIDTH_FP32; + +/* Copy in c matrix to buffer */ +ptr_result = c + (size_t)m_overall_start*handle->n + n_overall_start; +if (LIBXSMM_FEQ(0.f, *beta)) { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + } + } else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n)*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + } + for (n = last_n_start; n < num_n; n++) { + scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = 0; + } + } + } +} +else if (LIBXSMM_FEQ(1.f, *beta)) { + if ('T' == transc || 't' == transc) { + int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int m2; + + ptr_result = c + (size_t)n_overall_start*handle->m + m_overall_start; + + for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) { + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_result + (size_t)n*handle->m + m, handle->m, scratch_C + (size_t)m*n_block_size + n, n_block_size); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */ + for (m2 = m; m2 < m + SIMD_WIDTH_FP32; m2++) { + for (n = num_n_simd; n < num_n; n++) { + scratch_C[m2*n_block_size + n] = ptr_result[n*handle->m + m2]; + } + } + } + /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */ + for (m = num_m_simd; m < num_m; m++) { + for (n = 0; n < num_n; n++) { + scratch_C[m*n_block_size + n] = ptr_result[n*handle->m + m]; + } + } + } + else { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 0*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 1*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 2*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 3*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 4*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 5*SIMD_WIDTH_FP32)); + } + } + else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n) *SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n+1)*SIMD_WIDTH_FP32)); + } + for (n = last_n_start; n < num_n; n++) { + scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32+n] = ptr_result[m*handle->n+n]; + } + } + } + } +} +else { + SIMDTYPE_FP32 beta_v = _MM_SET1_FP32(*beta); + if ('T' == transc || 't' == transc) { + int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int m2; + + ptr_result = c + (size_t)n_overall_start*handle->m + m_overall_start; + + for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) { + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_result + (size_t)n*handle->m + m, handle->m, scratch_C + (size_t)m*n_block_size + n, n_block_size); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*1, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*1))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*2, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*2))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*3, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*3))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*4, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*4))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*5, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*5))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*6, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*6))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*7, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*7))); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */ + for (m2 = m; m2 < m + SIMD_WIDTH_FP32; m2++) { + for (n = num_n_simd; n < num_n; n++) { + scratch_C[m2*n_block_size + n] = (*beta)*ptr_result[n*handle->m + m2]; + } + } + } + /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */ + for (m = num_m_simd; m < num_m; m++) { + for (n = 0; n < num_n; n++) { + scratch_C[m*n_block_size + n] = (*beta)*ptr_result[n*handle->m + m]; + } + } + + } + else { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 0*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 1*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 2*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 3*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 4*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 5*SIMD_WIDTH_FP32))); + } + } + else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n) *SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n+1)*SIMD_WIDTH_FP32))); + } + for (n = last_n_start; n < num_n; n++) { + scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = (*beta)*ptr_result[m*handle->n + n]; + } + } + } + } +} + +for (kb = 0; kb < k_blocks; kb++) { + const uint16_t* LIBXSMM_RESTRICT ptr_dense; + float * LIBXSMM_RESTRICT scratch_C_base; + const float * LIBXSMM_RESTRICT scratch_B_base; + int block_A = kb * m_blocks + mb; + libxsmm_CSR_sparseslice slice = a_sparse[block_A]; + int m_local = 0; + + k_overall_start = kb*k_block_size; + k_overall_end = (kb+1)*k_block_size; + num_k = (k_overall_end - k_overall_start); + + /* Copy in b matrix */ + if ('T' == transb || 't' == transb) { + int num_k_simd = num_k / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int k2; + + ptr_dense = b + (size_t)n_overall_start*handle->k + k_overall_start; + + for (k = 0; k < num_k_simd; k += SIMD_WIDTH_FP32) { + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL_BFLOAT16(ptr_dense + (size_t)n*handle->k + k, handle->k, scratch_B + (size_t)k*n_block_size + n, n_block_size); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */ + for (k2 = k; k2 < k + SIMD_WIDTH_FP32; k2++) { + for (n = num_n_simd; n < num_n; n++) { + uint16_t restmp = ptr_dense[n*handle->k + k2]; + union { int i; float f; } res; + res.i = restmp; + res.i <<= 16; + scratch_B[k2*n_block_size + n] = res.f; + } + } + } + /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */ + for (k = num_k_simd; k < num_k; k++) { + for (n = 0; n < num_n; n++) { + uint16_t restmp = ptr_dense[n*handle->k + k]; + union { int i; float f; } res; + res.i = restmp; + res.i <<= 16; + scratch_B[k*n_block_size + n] = res.f; + } + } + } else { + ptr_dense = b + (size_t)k_overall_start*handle->n + n_overall_start; + if (!last_block_n) { + for (k = 0; k < num_k; k++) { + SIMDTYPE_INT32 vload_0 = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(ptr_dense + (size_t)k*handle->n + 2*0*SIMD_WIDTH_FP32)); + SIMDTYPE_INT32 vload_1, vload_2; + SIMDTYPE_FP32 v1_0, v2_0; + SIMDTYPE_FP32 v1_1, v2_1; + SIMDTYPE_FP32 v1_2, v2_2; + EXPAND_BFLOAT16(vload_0, v1_0, v2_0); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*0*SIMD_WIDTH_FP32, v1_0); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + (2*0+1)*SIMD_WIDTH_FP32, v2_0); + vload_1 = _MM_LOADU_INT32((const SIMDTYPE_INT32 *)(ptr_dense + (size_t)k*handle->n + 2*1*SIMD_WIDTH_FP32)); + EXPAND_BFLOAT16(vload_1, v1_1, v2_1); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*1*SIMD_WIDTH_FP32, v1_1); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + (2*1+1)*SIMD_WIDTH_FP32, v2_1); + vload_2 = _MM_LOADU_INT32((const SIMDTYPE_INT32 *)(ptr_dense + (size_t)k*handle->n + 2*2*SIMD_WIDTH_FP32)); + EXPAND_BFLOAT16(vload_2, v1_2, v2_2); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*2*SIMD_WIDTH_FP32, v1_2); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + (2*2+1)*SIMD_WIDTH_FP32, v2_2); + } + } else { + for (k = 0; k < num_k; k++) { + for (n = 0; n < num_full_regs; n += 2) { + SIMDTYPE_INT32 vload_0 = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(ptr_dense + (size_t)k*handle->n + (size_t)n*SIMD_WIDTH_FP32)); + SIMDTYPE_FP32 v1_0, v2_0; + EXPAND_BFLOAT16(vload_0, v1_0, v2_0); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, v1_0); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, v2_0); + } + for (n = last_n_start; n < num_n; n++) { + uint16_t restmp = ptr_dense[k*handle->n + n]; + union { int i; float f; } res; + res.i = restmp; + res.i <<= 16; + { + scratch_B[k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = res.f; + } + } + } + } + } + + scratch_C_base = scratch_C - (size_t)m_overall_start*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + scratch_B_base = scratch_B; /* - (size_t)k_overall_start*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; */ + + for (m = m_overall_start; m < m_overall_start + num_m_aligned; m += 2, m_local += 2) { + int start_j, end_j, end_j_2, num_j, num_j_2; + const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base; + const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base_2; + const float *LIBXSMM_RESTRICT sp_v_ptr_base; + const float *LIBXSMM_RESTRICT sp_v_ptr_base_2; + float *const LIBXSMM_RESTRICT result_m_index = scratch_C_base + ((size_t)m) *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + float *const LIBXSMM_RESTRICT result_m_index_2 = scratch_C_base + ((size_t)m+1)*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + + if (m_local >= m_block_size) { block_A++; slice = a_sparse[block_A]; m_local = 0; } + + start_j = slice.rowidx[m_local]; + end_j = slice.rowidx[m_local + 1]; + end_j_2 = slice.rowidx[m_local + 2]; + num_j = (end_j - start_j); + num_j_2 = (end_j_2 - end_j); + sp_c_ptr_base = slice.colidx + start_j; + sp_c_ptr_base_2 = slice.colidx + end_j; + sp_v_ptr_base = (float *)(slice.values) + start_j; + sp_v_ptr_base_2 = (float *)(slice.values) + end_j; + + if (!last_block_n) + { + int64_t j = 0, j2 = 0; + sum[0] = _MM_LOAD_FP32(result_m_index + 0*SIMD_WIDTH_FP32); + sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 0*SIMD_WIDTH_FP32); + sum[1] = _MM_LOAD_FP32(result_m_index + 1*SIMD_WIDTH_FP32); + sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 1*SIMD_WIDTH_FP32); + sum[2] = _MM_LOAD_FP32(result_m_index + 2*SIMD_WIDTH_FP32); + sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 2*SIMD_WIDTH_FP32); + sum[3] = _MM_LOAD_FP32(result_m_index + 3*SIMD_WIDTH_FP32); + sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 3*SIMD_WIDTH_FP32); + sum[4] = _MM_LOAD_FP32(result_m_index + 4*SIMD_WIDTH_FP32); + sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 4*SIMD_WIDTH_FP32); + sum[5] = _MM_LOAD_FP32(result_m_index + 5*SIMD_WIDTH_FP32); + sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 5*SIMD_WIDTH_FP32); + for (; j < num_j && j2 < num_j_2; j++, j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j] *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]); + sum[0 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 0*SIMD_WIDTH_FP32), sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]); + sum[1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 1*SIMD_WIDTH_FP32), sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]); + sum[2 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 2*SIMD_WIDTH_FP32), sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]); + sum[3 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 3*SIMD_WIDTH_FP32), sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]); + sum[4 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 4*SIMD_WIDTH_FP32), sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]); + sum[5 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 5*SIMD_WIDTH_FP32), sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]); + sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]); + sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]); + sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]); + sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]); + sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]); + } + for (; j2 < num_j_2; j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + sum[0 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 0*SIMD_WIDTH_FP32), sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 1*SIMD_WIDTH_FP32), sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[2 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 2*SIMD_WIDTH_FP32), sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[3 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 3*SIMD_WIDTH_FP32), sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[4 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 4*SIMD_WIDTH_FP32), sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[5 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 5*SIMD_WIDTH_FP32), sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + _MM_STORE_FP32(result_m_index + 0*SIMD_WIDTH_FP32, sum[0]); + _MM_STORE_FP32(result_m_index_2 + 0*SIMD_WIDTH_FP32, sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 1*SIMD_WIDTH_FP32, sum[1]); + _MM_STORE_FP32(result_m_index_2 + 1*SIMD_WIDTH_FP32, sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 2*SIMD_WIDTH_FP32, sum[2]); + _MM_STORE_FP32(result_m_index_2 + 2*SIMD_WIDTH_FP32, sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 3*SIMD_WIDTH_FP32, sum[3]); + _MM_STORE_FP32(result_m_index_2 + 3*SIMD_WIDTH_FP32, sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 4*SIMD_WIDTH_FP32, sum[4]); + _MM_STORE_FP32(result_m_index_2 + 4*SIMD_WIDTH_FP32, sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 5*SIMD_WIDTH_FP32, sum[5]); + _MM_STORE_FP32(result_m_index_2 + 5*SIMD_WIDTH_FP32, sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + else { + int64_t j = 0, j2 = 0; + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_SETZERO_FP32(); + sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_SETZERO_FP32(); + sum[n+1] = _MM_SETZERO_FP32(); + sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_SETZERO_FP32(); + } + for (; j < num_j && j2 < num_j_2; j++, j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j] *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + (size_t)n*SIMD_WIDTH_FP32), sum[n]); + sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + (size_t)n*SIMD_WIDTH_FP32), sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]); + sum[n+1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + { + float v_v_f = sp_v_ptr_base[j]; + float v_v_f_2 = sp_v_ptr_base_2[j2]; + for (n = last_n_start; n < num_n; n++) { + result_m_index[n] += sp_col_dense_index[n]*v_v_f; + result_m_index_2[n] += sp_col_dense_index_2[n]*v_v_f_2; + } + } + } + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n) *SIMD_WIDTH_FP32), sum[n]); + sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]); + } + { + float v_v_f = sp_v_ptr_base[j]; + for (n = last_n_start; n < num_n; n++) { + result_m_index[n] += sp_col_dense_index[n]*v_v_f; + } + } + } + for (; j2 < num_j_2; j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n) *SIMD_WIDTH_FP32), sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[n+1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + { + float v_v_f_2 = sp_v_ptr_base_2[j2]; + for (n = last_n_start; n < num_n; n++) { + result_m_index_2[n] += sp_col_dense_index_2[n]*v_v_f_2; + } + } + } + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n], _MM_LOAD_FP32(result_m_index + (size_t)n*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index_2 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS], _MM_LOAD_FP32(result_m_index_2 + (size_t)n*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1], _MM_LOAD_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS], _MM_LOAD_FP32(result_m_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32))); + } + } + } + for (m = m_overall_start + num_m_aligned; m < m_overall_end; m++, m_local++) { + int start_j, end_j, num_j; + const uint16_t* LIBXSMM_RESTRICT sp_c_ptr_base; + const float* LIBXSMM_RESTRICT sp_v_ptr_base; + float* LIBXSMM_RESTRICT result_m_index; + + if (m_local >= m_block_size) { block_A++; slice = a_sparse[block_A]; m_local = 0; } + + start_j = slice.rowidx[m_local]; + end_j = slice.rowidx[m_local + 1]; + num_j = (end_j - start_j); + sp_c_ptr_base = slice.colidx + start_j; + sp_v_ptr_base = slice.values + start_j; + result_m_index = scratch_C_base + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + + if (!last_block_n) { + int64_t j = 0; + sum[0] = _MM_LOAD_FP32(result_m_index + 0*SIMD_WIDTH_FP32); + sum[1] = _MM_LOAD_FP32(result_m_index + 1*SIMD_WIDTH_FP32); + sum[2] = _MM_LOAD_FP32(result_m_index + 2*SIMD_WIDTH_FP32); + sum[3] = _MM_LOAD_FP32(result_m_index + 3*SIMD_WIDTH_FP32); + sum[4] = _MM_LOAD_FP32(result_m_index + 4*SIMD_WIDTH_FP32); + sum[5] = _MM_LOAD_FP32(result_m_index + 5*SIMD_WIDTH_FP32); + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]); + sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]); + sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]); + sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]); + sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]); + sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]); + } + _MM_STORE_FP32(result_m_index + 0*SIMD_WIDTH_FP32, sum[0]); + _MM_STORE_FP32(result_m_index + 1*SIMD_WIDTH_FP32, sum[1]); + _MM_STORE_FP32(result_m_index + 2*SIMD_WIDTH_FP32, sum[2]); + _MM_STORE_FP32(result_m_index + 3*SIMD_WIDTH_FP32, sum[3]); + _MM_STORE_FP32(result_m_index + 4*SIMD_WIDTH_FP32, sum[4]); + _MM_STORE_FP32(result_m_index + 5*SIMD_WIDTH_FP32, sum[5]); + } + else { + int64_t j = 0; + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_SETZERO_FP32(); + sum[n+1] = _MM_SETZERO_FP32(); + } + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n) *SIMD_WIDTH_FP32), sum[n]); + sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]); + } + { + float v_v_f = sp_v_ptr_base[j]; + for (n = last_n_start; n < num_n; n++) { + result_m_index[n] += sp_col_dense_index[n]*v_v_f; + } + } + } + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n], _MM_LOAD_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1], _MM_LOAD_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32))); + } + + } + } +} /* kb */ + +/* Copy out c matrix */ +if ('T' == transc || 't' == transc) { + int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int n2; + + ptr_result = c + (size_t)n_overall_start*handle->m + m_overall_start; + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(scratch_C + (size_t)m*n_block_size + n, n_block_size, ptr_result + (size_t)n*handle->m + m, handle->m); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_m - num_m_simd) block of output space - input is of size (num_m - num_m_simd) * SIMD_WIDTH_FP32 */ + for (n2 = n; n2 < n + SIMD_WIDTH_FP32; n2++) { + for (m = num_m_simd; m < num_m; m++) { + ptr_result[n2*handle->m + m] = scratch_C[m*n_block_size + n2]; + } + } + } + /* Transpose a (num_n - num_n_simd) * num_m block of output space - input is of size num_m * (num_n - num_n_simd) */ + for (n = num_n_simd; n < num_n; n++) { + for (m = 0; m < num_m; m++) { + ptr_result[n*handle->m + m] = scratch_C[m*n_block_size + n]; + } + } +} +else { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 0*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 1*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 2*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 3*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 4*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 5*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32)); + } + } + else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n) *SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32)); + } + for (n = last_n_start; n < num_n; n++) { + ptr_result[m*handle->n + n] = scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n]; + } + } + } +} + +#undef LIBXSMM_SPMDM_COMPUTE_NREGS diff --git a/third_party/libxsmm/src/template/libxsmm_spmdm_compute_fp32_thread.tpl.c b/third_party/libxsmm/src/template/libxsmm_spmdm_compute_fp32_thread.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..18601cc8bbcf2cc48d5f8c4d7d7afb35a906629f --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_spmdm_compute_fp32_thread.tpl.c @@ -0,0 +1,542 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish (Intel Corp.) +******************************************************************************/ + +const int m_blocks = handle->mb; +/* const int n_blocks = handle->nb; */ +const int k_blocks = handle->kb; +const int m_block_size = handle->bm; +const int n_block_size = handle->bn; +const int k_block_size = handle->bk; +const int handle_m = handle->m; +const int handle_n = handle->n; +int mb = block_id / handle->nb; +int nb = block_id % handle->nb; + +#define LIBXSMM_SPMDM_COMPUTE_NREGS (6) +int m_overall_start = mb*m_block_size; +int m_overall_end = (mb + 1)*m_block_size; +int num_m; +int num_m_aligned; + +int n_overall_start = nb*n_block_size; +int n_overall_end = (nb + 1)*n_block_size; +int num_n; +int m, n, k, kb; +int last_block_n, num_full_regs, last_n_start; + +int k_overall_start, k_overall_end, num_k; + +float *const scratch_C = (float*)(handle->base_ptr_scratch_B_scratch_C + (size_t)tid*handle->memory_for_scratch_per_thread); +float *const scratch_B = (float*)(handle->base_ptr_scratch_B_scratch_C + (size_t)tid*handle->memory_for_scratch_per_thread + (size_t)m_block_size*n_block_size*sizeof(float)); +float* LIBXSMM_RESTRICT ptr_result; + +LIBXSMM_UNUSED(nthreads); +LIBXSMM_UNUSED(transa); +LIBXSMM_UNUSED(alpha); +LIBXSMM_UNUSED(beta); +LIBXSMM_UNUSED(tid); + +/* really is twice this */ +assert(n_block_size == LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32); + +if (m_overall_end > handle_m) m_overall_end = handle_m; +num_m = (m_overall_end - m_overall_start); +num_m_aligned = (num_m / 2) * 2; + +if (n_overall_end > handle_n) n_overall_end = handle_n; +num_n = (n_overall_end - n_overall_start); +last_block_n = (num_n != n_block_size); +num_full_regs = (num_n / SIMD_WIDTH_FP32); +if ((num_full_regs > 0) && (num_full_regs%2)) num_full_regs--; +last_n_start = num_full_regs*SIMD_WIDTH_FP32; + +/* Copy in c matrix to buffer*/ +ptr_result = c + (size_t)m_overall_start*handle_n + n_overall_start; +if (LIBXSMM_FEQ(0.f, *beta)) { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + } + } else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_SETZERO_FP32()); + } + for (n = last_n_start; n < num_n; n++) { + scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = 0; + } + } + } +} +else if (LIBXSMM_FEQ(1.f, *beta)) { + if ('T' == transc || 't' == transc) { + int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int m2; + + ptr_result = c + (size_t)n_overall_start*handle_m + m_overall_start; + + for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) { + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_result + (size_t)n*handle_m + m, handle_m, scratch_C + (size_t)m*n_block_size + n, n_block_size); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */ + for (m2 = m; m2 < m + SIMD_WIDTH_FP32; m2++) { + for (n = num_n_simd; n < num_n; n++) { + scratch_C[m2*n_block_size + n] = ptr_result[n*handle_m + m2]; + } + } + } + /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */ + for (m = num_m_simd; m < num_m; m++) { + for (n = 0; n < num_n; n++) { + scratch_C[m*n_block_size + n] = ptr_result[n*handle_m + m]; + } + } + } + else { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 0*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 1*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 2*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 3*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 4*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 5*SIMD_WIDTH_FP32)); + } + } + else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + ((size_t)n) *SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + ((size_t)n+1)*SIMD_WIDTH_FP32)); + } + for (n = last_n_start; n < num_n; n++) { + scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = ptr_result[m*handle_n + n]; + } + } + } + } +} +else { + SIMDTYPE_FP32 beta_v = _MM_SET1_FP32(*beta); + if ('T' == transc || 't' == transc) { + int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int m2; + + ptr_result = c + (size_t)n_overall_start*handle_m + m_overall_start; + + for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) { + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_result + (size_t)n*handle_m + m, handle_m, scratch_C + (size_t)m*n_block_size + n, n_block_size); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*1, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*1))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*2, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*2))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*3, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*3))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*4, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*4))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*5, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*5))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*6, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*6))); + _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*7, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*7))); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */ + for (m2 = m; m2 < m + SIMD_WIDTH_FP32; m2++) { + for (n = num_n_simd; n < num_n; n++) { + scratch_C[m2*n_block_size + n] = (*beta)*ptr_result[n*handle_m + m2]; + } + } + } + /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */ + for (m = num_m_simd; m < num_m; m++) { + for (n = 0; n < num_n; n++) { + scratch_C[m*n_block_size + n] = (*beta)*ptr_result[n*handle_m + m]; + } + } + + } + else { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 0*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 1*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 2*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 3*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 4*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + 5*SIMD_WIDTH_FP32))); + } + } + else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + ((size_t)n) *SIMD_WIDTH_FP32))); + _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle_n + ((size_t)n+1)*SIMD_WIDTH_FP32))); + } + for (n = last_n_start; n < num_n; n++) { + scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = (*beta)*ptr_result[m*handle_n + n]; + } + } + } + } +} + +for (kb = 0; kb < k_blocks; kb++) { + const float * LIBXSMM_RESTRICT ptr_dense; + float * LIBXSMM_RESTRICT scratch_C_base; + const float * LIBXSMM_RESTRICT scratch_B_base; + int block_A = kb * m_blocks + mb; + libxsmm_CSR_sparseslice slice = a_sparse[block_A]; + int m_local = 0; + + k_overall_start = kb*k_block_size; + k_overall_end = (kb+1)*k_block_size; + if (k_overall_end > handle->k) k_overall_end = handle->k; + num_k = (k_overall_end - k_overall_start); + + /* Copy in b matrix*/ + if ('T' == transb || 't' == transb) { + int num_k_simd = num_k / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int k2; + + ptr_dense = b + (size_t)n_overall_start*handle->k + k_overall_start; + + for (k = 0; k < num_k_simd; k += SIMD_WIDTH_FP32) { + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_dense + (size_t)n*handle->k + k, handle->k, scratch_B + (size_t)k*n_block_size + n, n_block_size); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */ + for (k2 = k; k2 < k + SIMD_WIDTH_FP32; k2++) { + for (n = num_n_simd; n < num_n; n++) { + scratch_B[k2*n_block_size + n] = ptr_dense[n*handle->k + k2]; + } + } + } + /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */ + for (k = num_k_simd; k < num_k; k++) { + for (n = 0; n < num_n; n++) { + scratch_B[k*n_block_size + n] = ptr_dense[n*handle->k + k]; + } + } + } + else { + ptr_dense = b + (size_t)k_overall_start*handle_n + n_overall_start; + if (!last_block_n) { + for (k = 0; k < num_k; k++) { + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + 0*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + 1*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + 2*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + 3*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + 4*SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + 5*SIMD_WIDTH_FP32)); + } + } else { + for (k = 0; k < num_k; k++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + ((size_t)n) *SIMD_WIDTH_FP32)); + _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_dense + (size_t)k*handle_n + ((size_t)n+1)*SIMD_WIDTH_FP32)); + } + for (n = last_n_start; n < num_n; n++) { + scratch_B[k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = ptr_dense[k*handle_n + n]; + } + } + } + } + + scratch_C_base = scratch_C - (size_t)m_overall_start*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + scratch_B_base = scratch_B; /* - (size_t)k_overall_start*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;*/ + + for (m = m_overall_start; m < m_overall_start + num_m_aligned; m += 2, m_local += 2) { + int start_j, end_j, end_j_2, num_j, num_j_2; + const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base; + const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base_2; + const float *LIBXSMM_RESTRICT sp_v_ptr_base; + const float *LIBXSMM_RESTRICT sp_v_ptr_base_2; + float *LIBXSMM_RESTRICT result_m_index; + float *LIBXSMM_RESTRICT result_m_index_2; + const uint16_t* rowidx; + + if (m_local >= m_block_size) { block_A++; slice = a_sparse[block_A]; m_local = 0; } + + rowidx = slice.rowidx; + start_j = rowidx[m_local]; + end_j = rowidx[m_local+1]; + end_j_2 = rowidx[m_local+2]; + num_j = (end_j - start_j); + num_j_2 = (end_j_2 - end_j); + sp_c_ptr_base = slice.colidx + start_j; + sp_c_ptr_base_2 = slice.colidx + end_j; + sp_v_ptr_base = (float *)(slice.values) + start_j; + sp_v_ptr_base_2 = (float *)(slice.values) + end_j; + result_m_index = scratch_C_base + ((size_t)m) *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + result_m_index_2 = scratch_C_base + ((size_t)m+1)*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + + if (!last_block_n) + { + int64_t j = 0, j2 = 0; + SIMDTYPE_FP32 sum[2*LIBXSMM_SPMDM_COMPUTE_NREGS]; + sum[0] = _MM_LOAD_FP32(result_m_index + 0*SIMD_WIDTH_FP32); + sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 0*SIMD_WIDTH_FP32); + sum[1] = _MM_LOAD_FP32(result_m_index + 1*SIMD_WIDTH_FP32); + sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 1*SIMD_WIDTH_FP32); + sum[2] = _MM_LOAD_FP32(result_m_index + 2*SIMD_WIDTH_FP32); + sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 2*SIMD_WIDTH_FP32); + sum[3] = _MM_LOAD_FP32(result_m_index + 3*SIMD_WIDTH_FP32); + sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 3*SIMD_WIDTH_FP32); + sum[4] = _MM_LOAD_FP32(result_m_index + 4*SIMD_WIDTH_FP32); + sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 4*SIMD_WIDTH_FP32); + sum[5] = _MM_LOAD_FP32(result_m_index + 5*SIMD_WIDTH_FP32); + sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 5*SIMD_WIDTH_FP32); + for (; j < num_j && j2 < num_j_2; j++, j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]); + sum[0 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 0*SIMD_WIDTH_FP32), sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]); + sum[1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 1*SIMD_WIDTH_FP32), sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]); + sum[2 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 2*SIMD_WIDTH_FP32), sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]); + sum[3 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 3*SIMD_WIDTH_FP32), sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]); + sum[4 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 4*SIMD_WIDTH_FP32), sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]); + sum[5 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 5*SIMD_WIDTH_FP32), sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]); + sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]); + sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]); + sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]); + sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]); + sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]); + } + for (; j2 < num_j_2; j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + sum[0 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 0*SIMD_WIDTH_FP32), sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 1*SIMD_WIDTH_FP32), sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[2 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 2*SIMD_WIDTH_FP32), sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[3 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 3*SIMD_WIDTH_FP32), sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[4 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 4*SIMD_WIDTH_FP32), sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[5 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 5*SIMD_WIDTH_FP32), sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + _MM_STORE_FP32(result_m_index + 0*SIMD_WIDTH_FP32, sum[0]); + _MM_STORE_FP32(result_m_index_2 + 0*SIMD_WIDTH_FP32, sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 1*SIMD_WIDTH_FP32, sum[1]); + _MM_STORE_FP32(result_m_index_2 + 1*SIMD_WIDTH_FP32, sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 2*SIMD_WIDTH_FP32, sum[2]); + _MM_STORE_FP32(result_m_index_2 + 2*SIMD_WIDTH_FP32, sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 3*SIMD_WIDTH_FP32, sum[3]); + _MM_STORE_FP32(result_m_index_2 + 3*SIMD_WIDTH_FP32, sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 4*SIMD_WIDTH_FP32, sum[4]); + _MM_STORE_FP32(result_m_index_2 + 4*SIMD_WIDTH_FP32, sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]); + _MM_STORE_FP32(result_m_index + 5*SIMD_WIDTH_FP32, sum[5]); + _MM_STORE_FP32(result_m_index_2 + 5*SIMD_WIDTH_FP32, sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + else { + int64_t j = 0, j2 = 0; + SIMDTYPE_FP32 sum[2*LIBXSMM_SPMDM_COMPUTE_NREGS]; + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_SETZERO_FP32(); + sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_SETZERO_FP32(); + sum[n+1] = _MM_SETZERO_FP32(); + sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_SETZERO_FP32(); + } + for (; j < num_j && j2 < num_j_2; j++, j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + (size_t)n*SIMD_WIDTH_FP32), sum[n]); + sum[n + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + (size_t)n*SIMD_WIDTH_FP32), sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]); + sum[n+1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + { + float v_v_f = sp_v_ptr_base[j]; + float v_v_f_2 = sp_v_ptr_base_2[j2]; + for (n = last_n_start; n < num_n; n++) { + result_m_index[n] += sp_col_dense_index[n]*v_v_f; + result_m_index_2[n] += sp_col_dense_index_2[n]*v_v_f_2; + } + } + } + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n) *SIMD_WIDTH_FP32), sum[n]); + sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]); + } + { + float v_v_f = sp_v_ptr_base[j]; + for (n = last_n_start; n < num_n; n++) { + result_m_index[n] += sp_col_dense_index[n]*v_v_f; + } + } + } + for (; j2 < num_j_2; j2++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n) *SIMD_WIDTH_FP32), sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS]); + sum[n+1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS]); + } + { + float v_v_f_2 = sp_v_ptr_base_2[j2]; + for (n = last_n_start; n < num_n; n++) { + result_m_index_2[n] += sp_col_dense_index_2[n]*v_v_f_2; + } + } + } + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n], _MM_LOAD_FP32(result_m_index + (size_t)n*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index_2 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS], _MM_LOAD_FP32(result_m_index_2 + (size_t)n*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1], _MM_LOAD_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS], _MM_LOAD_FP32(result_m_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32))); + } + } + } + for (m = m_overall_start + num_m_aligned; m < m_overall_end; m++, m_local++) { + int start_j, end_j, num_j; + const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base; + const float *LIBXSMM_RESTRICT sp_v_ptr_base; + float *LIBXSMM_RESTRICT result_m_index; + const uint16_t* rowidx; + + if (m_local >= m_block_size) { block_A++; slice = a_sparse[block_A]; m_local = 0; } + + rowidx = slice.rowidx; + start_j = rowidx[m_local]; + end_j = rowidx[m_local+1]; + num_j = (end_j - start_j); + sp_c_ptr_base = slice.colidx + start_j; + sp_v_ptr_base = slice.values + start_j; + result_m_index = scratch_C_base + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + + if (!last_block_n) { + int64_t j = 0; + SIMDTYPE_FP32 sum[2*LIBXSMM_SPMDM_COMPUTE_NREGS]; + sum[0] = _MM_LOAD_FP32(result_m_index + 0*SIMD_WIDTH_FP32); + sum[1] = _MM_LOAD_FP32(result_m_index + 1*SIMD_WIDTH_FP32); + sum[2] = _MM_LOAD_FP32(result_m_index + 2*SIMD_WIDTH_FP32); + sum[3] = _MM_LOAD_FP32(result_m_index + 3*SIMD_WIDTH_FP32); + sum[4] = _MM_LOAD_FP32(result_m_index + 4*SIMD_WIDTH_FP32); + sum[5] = _MM_LOAD_FP32(result_m_index + 5*SIMD_WIDTH_FP32); + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]); + sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]); + sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]); + sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]); + sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]); + sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]); + } + _MM_STORE_FP32(result_m_index + 0*SIMD_WIDTH_FP32, sum[0]); + _MM_STORE_FP32(result_m_index + 1*SIMD_WIDTH_FP32, sum[1]); + _MM_STORE_FP32(result_m_index + 2*SIMD_WIDTH_FP32, sum[2]); + _MM_STORE_FP32(result_m_index + 3*SIMD_WIDTH_FP32, sum[3]); + _MM_STORE_FP32(result_m_index + 4*SIMD_WIDTH_FP32, sum[4]); + _MM_STORE_FP32(result_m_index + 5*SIMD_WIDTH_FP32, sum[5]); + } + else { + SIMDTYPE_FP32 sum[2*LIBXSMM_SPMDM_COMPUTE_NREGS]; + int64_t j = 0; + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_SETZERO_FP32(); + sum[n+1] = _MM_SETZERO_FP32(); + } + for (; j < num_j; j++) { + const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; + SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]); + for (n = 0; n < num_full_regs; n += 2) { + sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n) *SIMD_WIDTH_FP32), sum[n]); + sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]); + } + { + float v_v_f = sp_v_ptr_base[j]; + for (n = last_n_start; n < num_n; n++) { + result_m_index[n] += sp_col_dense_index[n]*v_v_f; + } + } + } + for (n = 0; n < num_full_regs; n += 2) { + _MM_STORE_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n], _MM_LOAD_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32))); + _MM_STORE_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1], _MM_LOAD_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32))); + } + } + } +} /* kb */ + +/* Copy out c matrix */ +if ('T' == transc || 't' == transc) { + int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32; + int n2; + + ptr_result = c + (size_t)n_overall_start*handle_m + m_overall_start; + for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) { + for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) { + TRANSPOSE_SIMD_WIDTH_KERNEL(scratch_C + (size_t)m*n_block_size + n, n_block_size, ptr_result + (size_t)n*handle_m + m, handle_m); + } + /* Transpose a SIMD_WIDTH_FP32 * (num_m - num_m_simd) block of output space - input is of size (num_m - num_m_simd) * SIMD_WIDTH_FP32 */ + for (n2 = n; n2 < n + SIMD_WIDTH_FP32; n2++) { + for (m = num_m_simd; m < num_m; m++) { + ptr_result[n2*handle_m + m] = scratch_C[m*n_block_size + n2]; + } + } + } + /* Transpose a (num_n - num_n_simd) * num_m block of output space - input is of size num_m * (num_n - num_n_simd) */ + for (n = num_n_simd; n < num_n; n++) { + for (m = 0; m < num_m; m++) { + ptr_result[n*handle_m + m] = scratch_C[m*n_block_size + n]; + } + } +} +else { + if (!last_block_n) { + for (m = 0; m < num_m; m++) { + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + 0*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + 1*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + 2*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + 3*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + 4*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + 5*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32)); + } + } + else { + for (m = 0; m < num_m; m++) { + for (n = 0; n < num_full_regs; n += 2) { + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + ((size_t)n)*SIMD_WIDTH_FP32, + _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32)); + _MM_STOREU_FP32(ptr_result + (size_t)m*handle_n + ((size_t)n+1)*SIMD_WIDTH_FP32, + _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32)); + } + for (n = last_n_start; n < num_n; n++) { + ptr_result[m*handle_n + n] = scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n]; + } + } + } +} + +#undef LIBXSMM_SPMDM_COMPUTE_NREGS diff --git a/third_party/libxsmm/src/template/libxsmm_spmdm_createSparseSlice_bfloat16_thread.tpl.c b/third_party/libxsmm/src/template/libxsmm_spmdm_createSparseSlice_bfloat16_thread.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..14b5720fed5876929c8320a0b005a9c20b6e03a6 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_spmdm_createSparseSlice_bfloat16_thread.tpl.c @@ -0,0 +1,126 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish (Intel Corp.) +******************************************************************************/ + +int i, k; +int mb, kb; +#if SIMD_WIDTH_FP32 == 8 +const __m256i *const shufmasks = internal_spmdm_shufmasks_32; +#endif +#if SIMD_WIDTH_FP32 > 1 +const __m256i *const shufmasks2 = internal_spmdm_shufmasks_16; +#endif +int block_offset_base, block_offset; + +LIBXSMM_UNUSED(nthreads); +LIBXSMM_UNUSED(tid); + +kb = block_id / handle->mb; +mb = block_id % handle->mb; + +if ('T' == transa || 't' == transa) { + block_offset_base = mb * handle->bm; + block_offset = block_offset_base + kb * handle->m * handle->bk; +} +else { + block_offset_base = kb * handle->bk; + block_offset = block_offset_base + mb * handle->k * handle->bm; +} +{ + libxsmm_CSR_sparseslice slice = libxsmm_output_csr_a[kb*handle->mb + mb]; + int nrows = ((mb + 1)*handle->bm > handle->m)?(handle->m - (mb)*handle->bm):handle->bm; + int ncols = ((kb + 1)*handle->bk > handle->k)?(handle->k - (kb)*handle->bk):handle->bk; + /*printf("nrows: %d, ncols: %d\n", nrows, ncols);*/ + const uint16_t * input_ptr = a + block_offset; + uint16_t * rowidx_ptr = slice.rowidx; + uint16_t * colidx_ptr = slice.colidx; + float * values_ptr = (float *)(slice.values); + uint16_t cnt = 0; +#if SIMD_WIDTH_FP32 > 1 + const SIMDTYPE_INT32 vzero = _MM_SETZERO_INT32(); + const SIMDTYPE_FP32 vzerof = _MM_SETZERO_FP32(); + const int ncols_aligned = ncols / (4*SIMD_WIDTH_FP32)*(4*SIMD_WIDTH_FP32); +#else + const int ncols_aligned = 0; +#endif + for (i = 0; i < nrows; i++) { + rowidx_ptr[i] = cnt; + if ('T' == transa || 't' == transa) { +#if SIMD_WIDTH_FP32 > 1 + for (k = 0; k < ncols_aligned; k += 4*SIMD_WIDTH_FP32) { + int vals[32]; + int kk; + for (kk = 0; kk < 4*SIMD_WIDTH_FP32; kk += 2) { vals[kk/2] = (int)input_ptr[(k+kk)*handle->m + i]; vals[kk/2] |= ((int)(input_ptr[(k+kk+1)*handle->m + i]) << 16); } + { + SIMDTYPE_INT32 v1tmp = _MM_LOADU_INT32(vals); + SIMDTYPE_INT32 v2tmp = _MM_LOADU_INT32(vals + SIMD_WIDTH_FP32); + SIMDTYPE_FP32 v1, v2, v3, v4; + SIMDMASKTYPE_FP32 m1, m2, m3, m4; + EXPAND_BFLOAT16(v1tmp, v1, v2); + EXPAND_BFLOAT16(v2tmp, v3, v4); + m1 = _MM_CMPNEQ_FP32(v1, vzerof); + m2 = _MM_CMPNEQ_FP32(v2, vzerof); + m3 = _MM_CMPNEQ_FP32(v3, vzerof); + m4 = _MM_CMPNEQ_FP32(v4, vzerof); + COMPRESS_FP32(v1, k, m1, cnt); + COMPRESS_FP32(v2, k + SIMD_WIDTH_FP32, m2, cnt); + COMPRESS_FP32(v3, k + 2*SIMD_WIDTH_FP32, m3, cnt); + COMPRESS_FP32(v4, k + 3*SIMD_WIDTH_FP32, m4, cnt); + } + } +#endif + for (k = ncols_aligned; k < ncols; k++) { + uint16_t v1tmp = input_ptr[k*handle->m + i]; + union {int i; float f; } v1tmp_int; + v1tmp_int.i = v1tmp; + v1tmp_int.i <<= 16; + { + const int m1 = LIBXSMM_FEQ(0, v1tmp_int.f) ? 0 : 1; + if (m1) { colidx_ptr[cnt] = (uint16_t)k; values_ptr[cnt] = v1tmp_int.f; cnt++; } + } + } + } + else { +#if SIMD_WIDTH_FP32 > 1 + for (k = 0; k < ncols_aligned; k += 4*SIMD_WIDTH_FP32) { + SIMDTYPE_INT32 v1tmp, v2tmp; + SIMDTYPE_FP32 v1, v2, v3, v4; + SIMDMASKTYPE_FP32 m1, m2, m3, m4; + v1tmp = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(input_ptr + (size_t)i*handle->k + k)); + _MM_PREFETCH((char *)(input_ptr + ((size_t)i+2)*handle->k + k), _MM_HINT_T0); + v2tmp = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(input_ptr + (size_t)i*handle->k + k + 2*SIMD_WIDTH_FP32)); + _MM_PREFETCH((char *)(input_ptr + ((size_t)i+2)*handle->k + k + SIMD_WIDTH_FP32), _MM_HINT_T0); + EXPAND_BFLOAT16(v1tmp, v1, v2); + EXPAND_BFLOAT16(v2tmp, v3, v4); + m1 = _MM_CMPNEQ_FP32(v1, vzerof); + m2 = _MM_CMPNEQ_FP32(v2, vzerof); + m3 = _MM_CMPNEQ_FP32(v3, vzerof); + m4 = _MM_CMPNEQ_FP32(v4, vzerof); + COMPRESS_FP32(v1, k, m1, cnt); + COMPRESS_FP32(v2, k + SIMD_WIDTH_FP32, m2, cnt); + COMPRESS_FP32(v3, k + 2*SIMD_WIDTH_FP32, m3, cnt); + COMPRESS_FP32(v4, k + 3*SIMD_WIDTH_FP32, m4, cnt); + } +#endif + for (k = ncols_aligned; k < ncols; k++) { + uint16_t v1tmp = input_ptr[i*handle->k + k]; + union {int i; float f; } v1tmp_int; + v1tmp_int.i = v1tmp; + v1tmp_int.i <<= 16; + { + int m1 = LIBXSMM_FEQ(0, v1tmp_int.f) ? 0 : 1; + if (m1) { colidx_ptr[cnt] = (uint16_t)k; values_ptr[cnt] = v1tmp_int.f; cnt++; } + } + } + } + } + rowidx_ptr[nrows] = cnt; +} + diff --git a/third_party/libxsmm/src/template/libxsmm_spmdm_createSparseSlice_fp32_thread.tpl.c b/third_party/libxsmm/src/template/libxsmm_spmdm_createSparseSlice_fp32_thread.tpl.c new file mode 100644 index 0000000000000000000000000000000000000000..7d1bb3559b99b7d1b3bb948129ac5a1291bcf62c --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_spmdm_createSparseSlice_fp32_thread.tpl.c @@ -0,0 +1,129 @@ +/****************************************************************************** +* Copyright (c) Intel Corporation - All rights reserved. * +* This file is part of the LIBXSMM library. * +* * +* For information on the license, see the LICENSE file. * +* Further information: https://github.com/hfp/libxsmm/ * +* SPDX-License-Identifier: BSD-3-Clause * +******************************************************************************/ +/* Nadathur Satish (Intel Corp.) +******************************************************************************/ + +int i, k; +int mb, kb; +#if SIMD_WIDTH_FP32 == 8 +const __m256i *const shufmasks = internal_spmdm_shufmasks_32; +#endif +#if SIMD_WIDTH_FP32 > 1 +const __m256i *const shufmasks2 = internal_spmdm_shufmasks_16; +SIMDTYPE_INT32 vindex = _MM_SETZERO_INT32(); +int idx_array[16]; +#endif +int block_offset_base, block_offset; + +LIBXSMM_UNUSED(nthreads); +LIBXSMM_UNUSED(tid); + +kb = block_id / handle->mb; +mb = block_id % handle->mb; +if ('T' == transa || 't' == transa) { +#if SIMD_WIDTH_FP32 > 1 + int kk; + for (kk = 0; kk < SIMD_WIDTH_FP32; kk++) idx_array[kk] = kk * handle->m; + vindex = _MM_LOADU_INT32(idx_array); +#endif + block_offset_base = mb * handle->bm; + block_offset = block_offset_base + kb * handle->m * handle->bk; +} +else { + block_offset_base = kb * handle->bk; + block_offset = block_offset_base + mb * handle->k * handle->bm; +} +{ + libxsmm_CSR_sparseslice slice = libxsmm_output_csr_a[kb*handle->mb + mb]; + int nrows = ((mb + 1)*handle->bm > handle->m)?(handle->m - (mb)*handle->bm):handle->bm; + int ncols = ((kb + 1)*handle->bk > handle->k)?(handle->k - (kb)*handle->bk):handle->bk; + /*printf("nrows: %d, ncols: %d\n", nrows, ncols);*/ + const float * input_ptr = a + block_offset; + uint16_t * rowidx_ptr = slice.rowidx; + uint16_t * colidx_ptr = slice.colidx; + float * values_ptr = (float *)(slice.values); + uint16_t cnt = 0; +#if SIMD_WIDTH_FP32 > 1 + const SIMDTYPE_FP32 vzero = _MM_SETZERO_FP32(); + const int ncols_aligned = ncols / (4*SIMD_WIDTH_FP32)*(4*SIMD_WIDTH_FP32); + const int ncols_aligned_2 = ncols / (SIMD_WIDTH_FP32)*(SIMD_WIDTH_FP32); +#else + const int ncols_aligned_2 = 0; +#endif + for (i = 0; i < nrows; i++) { + rowidx_ptr[i] = cnt; + if ('T' == transa || 't' == transa) { +#if SIMD_WIDTH_FP32 > 1 + for (k = 0; k < ncols_aligned; k += 4*SIMD_WIDTH_FP32) { + SIMDTYPE_FP32 v1 = _MM_GATHER_FP32(input_ptr + (size_t)k * handle->m + i, vindex, 4); + SIMDTYPE_FP32 v2 = _MM_GATHER_FP32(input_ptr + ((size_t)k+1*SIMD_WIDTH_FP32) * handle->m + i, vindex, 4); + SIMDTYPE_FP32 v3 = _MM_GATHER_FP32(input_ptr + ((size_t)k+2*SIMD_WIDTH_FP32) * handle->m + i, vindex, 4); + SIMDTYPE_FP32 v4 = _MM_GATHER_FP32(input_ptr + ((size_t)k+3*SIMD_WIDTH_FP32) * handle->m + i, vindex, 4); + SIMDMASKTYPE_FP32 m1 = _MM_CMPNEQ_FP32(v1, vzero); + SIMDMASKTYPE_FP32 m2 = _MM_CMPNEQ_FP32(v2, vzero); + SIMDMASKTYPE_FP32 m3 = _MM_CMPNEQ_FP32(v3, vzero); + SIMDMASKTYPE_FP32 m4 = _MM_CMPNEQ_FP32(v4, vzero); + COMPRESS_FP32(v1, k, m1, cnt); + COMPRESS_FP32(v2, k + SIMD_WIDTH_FP32, m2, cnt); + COMPRESS_FP32(v3, k + 2*SIMD_WIDTH_FP32, m3, cnt); + COMPRESS_FP32(v4, k + 3*SIMD_WIDTH_FP32, m4, cnt); + } + for (k = ncols_aligned; k < ncols_aligned_2; k += SIMD_WIDTH_FP32) { + SIMDTYPE_FP32 v1 = _MM_GATHER_FP32(input_ptr + (size_t)k * handle->m + i, vindex, 4); + SIMDMASKTYPE_FP32 m1 = _MM_CMPNEQ_FP32(v1, vzero); + COMPRESS_FP32(v1, k, m1, cnt); + } +#endif + for (k = ncols_aligned_2; k < ncols; k++) { + const float v1 = input_ptr[i + k*handle->m]; + const int m1 = LIBXSMM_FEQ(0, v1) ? 0 : 1; + if (m1) { colidx_ptr[cnt] = (uint16_t)k; values_ptr[cnt] = v1; cnt++; } + } + } + else { +#if SIMD_WIDTH_FP32 > 1 + for (k = 0; k < ncols_aligned; k += 4*SIMD_WIDTH_FP32) { + SIMDTYPE_FP32 v1, v2, v3, v4; + SIMDMASKTYPE_FP32 m1, m2, m3, m4; + v1 = _MM_LOADU_FP32(input_ptr + ((size_t)i) * handle->k + (size_t)k); + _MM_PREFETCH((char*)input_ptr + ((size_t)i+2) * handle->k + (size_t)k, _MM_HINT_T0); + v2 = _MM_LOADU_FP32(input_ptr + ((size_t)i) * handle->k + (size_t)k + (size_t)SIMD_WIDTH_FP32); + _MM_PREFETCH((char*)input_ptr + ((size_t)i+2) * handle->k + (size_t)k + (size_t)SIMD_WIDTH_FP32, _MM_HINT_T0); + v3 = _MM_LOADU_FP32(input_ptr + ((size_t)i) * handle->k + (size_t)k + (size_t)2 * SIMD_WIDTH_FP32); + _MM_PREFETCH((char*)input_ptr + ((size_t)i+2) * handle->k + (size_t)k + (size_t)2 * SIMD_WIDTH_FP32, _MM_HINT_T0); + v4 = _MM_LOADU_FP32(input_ptr + ((size_t)i) * handle->k + (size_t)k + (size_t)3 * SIMD_WIDTH_FP32); + _MM_PREFETCH((char*)input_ptr + ((size_t)i+2) * handle->k + (size_t)k + (size_t)3 * SIMD_WIDTH_FP32, _MM_HINT_T0); + m1 = _MM_CMPNEQ_FP32(v1, vzero); + m2 = _MM_CMPNEQ_FP32(v2, vzero); + m3 = _MM_CMPNEQ_FP32(v3, vzero); + m4 = _MM_CMPNEQ_FP32(v4, vzero); + COMPRESS_FP32(v1, k, m1, cnt); + COMPRESS_FP32(v2, k + SIMD_WIDTH_FP32, m2, cnt); + COMPRESS_FP32(v3, k + 2*SIMD_WIDTH_FP32, m3, cnt); + COMPRESS_FP32(v4, k + 3*SIMD_WIDTH_FP32, m4, cnt); + } + for (k = ncols_aligned; k < ncols_aligned_2; k += SIMD_WIDTH_FP32) { + SIMDTYPE_FP32 v1; + SIMDMASKTYPE_FP32 m1; + v1 = _MM_LOADU_FP32(input_ptr + ((size_t)i) * handle->k + (size_t)k); + _MM_PREFETCH((char*)input_ptr + ((size_t)i+2) * handle->k + (size_t)k, _MM_HINT_T0); + m1 = _MM_CMPNEQ_FP32(v1, vzero); + COMPRESS_FP32(v1, k, m1, cnt); + } +#endif + for (k = ncols_aligned_2; k < ncols; k++) { + const float v1 = input_ptr[i*handle->k + k]; + const int m1 = LIBXSMM_FEQ(0, v1) ? 0 : 1; + if (m1) { colidx_ptr[cnt] = (uint16_t)k; values_ptr[cnt] = v1; cnt++; } + } + } + } + rowidx_ptr[nrows] = cnt; +} + diff --git a/third_party/libxsmm/src/template/libxsmm_version.h b/third_party/libxsmm/src/template/libxsmm_version.h new file mode 100644 index 0000000000000000000000000000000000000000..43bec8511ff4a6ec18dd71736689d49b612c44f7 --- /dev/null +++ b/third_party/libxsmm/src/template/libxsmm_version.h @@ -0,0 +1,12 @@ +#ifndef LIBXSMM_VERSION_H +#define LIBXSMM_VERSION_H + +#define LIBXSMM_CONFIG_VERSION "$VERSION" +#define LIBXSMM_CONFIG_BRANCH "$BRANCH" +#define LIBXSMM_CONFIG_VERSION_MAJOR $MAJOR +#define LIBXSMM_CONFIG_VERSION_MINOR $MINOR +#define LIBXSMM_CONFIG_VERSION_UPDATE $UPDATE +#define LIBXSMM_CONFIG_VERSION_PATCH $PATCH +#define LIBXSMM_CONFIG_BUILD_DATE $DATE + +#endif diff --git a/third_party/libxsmm/tests/mhd_image.mhd b/third_party/libxsmm/tests/mhd_image.mhd new file mode 100644 index 0000000000000000000000000000000000000000..495486bbfb9e521801d3a83596807380a5beca5c --- /dev/null +++ b/third_party/libxsmm/tests/mhd_image.mhd @@ -0,0 +1,13 @@ +ObjectType = Image +NDims = 3 +BinaryData = True +BinaryDataByteOrderMSB = False +CompressedData = False +TransformMatrix = 1 0 0 0 1 0 0 0 1 +Offset = 0 0 0 +CenterOfRotation = 0 0 0 +AnatomicalOrientation = RAI +ElementSpacing = 1 1 1 +DimSize = 202 134 1 +ElementType = MET_SHORT +ElementDataFile = mhd_image.raw diff --git a/third_party/nanoflann/.gitignore b/third_party/nanoflann/.gitignore deleted file mode 100644 index 9850badc34d0a82913a693b3ede6e8a936c7f71b..0000000000000000000000000000000000000000 --- a/third_party/nanoflann/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -*~ -build* - diff --git a/third_party/pcg/.gitignore b/third_party/pcg/.gitignore deleted file mode 100644 index 9f598fd590f92caf87567696e278ff08c7edb81b..0000000000000000000000000000000000000000 --- a/third_party/pcg/.gitignore +++ /dev/null @@ -1,33 +0,0 @@ -# Compiled Object files -*.slo -*.lo -*.o -*.obj - -# Precompiled Headers -*.gch -*.pch - -# Compiled Dynamic libraries -*.so -*.dylib -*.dll - -# Fortran module files -*.mod - -# Compiled Static libraries -*.lai -*.la -*.a -*.lib - -# Debug Information -*.dSYM - -# Executables -*.exe -*.out -*.app - -# Actual Project Executables diff --git a/third_party/phmap/.gitignore b/third_party/phmap/.gitignore deleted file mode 100644 index b208e2410809f3dafed62cc589276ed9b545e133..0000000000000000000000000000000000000000 --- a/third_party/phmap/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -VagrantFile -benchmark/build -benchmark/output -benchmark/charts.html -build* -.vagrant -**/.vscode -TAGS diff --git a/third_party/tensorpipe/.gitignore b/third_party/tensorpipe/.gitignore deleted file mode 100644 index 16e39eac95f521ce03186d77f3f9cdaaf2de2d0e..0000000000000000000000000000000000000000 --- a/third_party/tensorpipe/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -*~ -.DS_Store -/build/ -/cmake-build-debug/ diff --git a/third_party/thrust/.gitignore b/third_party/thrust/.gitignore deleted file mode 100644 index 93835e48c58c393046687e46f803b60074c87f1e..0000000000000000000000000000000000000000 --- a/third_party/thrust/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -discrete_voronoi.pgm -*build*/ -.idea/ diff --git a/third_party/tvm/.gitignore b/third_party/tvm/.gitignore deleted file mode 100644 index cdcf6780a3f246ca2789f37d947f0c25d4b5f4d1..0000000000000000000000000000000000000000 --- a/third_party/tvm/.gitignore +++ /dev/null @@ -1,235 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class -*.S -# C extensions -*.so -*.ll -.npm -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -*.egg-info/ -.installed.cfg -*.egg -.conda/ -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*,cover -.hypothesis/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ -docs/gen_modules - -# PyBuilder -/target/ - -# IPython Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# dotenv -.env - -# virtualenv -venv/ -ENV/ - -# Spyder project settings -.spyderproject - -# Rope project settings -.ropeproject -*~ -*.pyc -*~ -config.mk -config.cmake -Win32 -*.dir -perf -*.wasm -.emscripten - -## IOS -DerivedData/ - -## Java -*.class -jvm/*/target/ -jvm/*/*/target/ -*.worksheet -*.idea -*.iml -*.classpath -*.project -*.settings -*/node_modules/ - -## Various settings -*.pbxuser -!default.pbxuser -*.mode1v3 -!default.mode1v3 -*.mode2v3 -!default.mode2v3 -*.perspectivev3 -!default.perspectivev3 -xcuserdata/ -.pkl_memoize_* - -.emscripten* -.m2 - -# Compiled Dynamic libraries -*.so -*.dylib -*.dll - -# Compiled Object files -*.slo -*.lo -*.o -*.obj - -# Precompiled Headers -*.gch -*.pch - -# Compiled Static libraries -*.lai -*.la -*.a -*.lib - -# Executables -*.exe -*.out -*.app - -## Other -*.moved-aside -*.xccheckout -*.xcscmblueprint -.DS_Store -tags -cscope* -*.lock - -# vim temporary files -*.swp -*.swo - -# TVM generated code -perf -.bash_history -*.json -*.params -*.onnx -*.h5 -synset.txt -cat.jpg -cat.png -docs.tgz -cat.png -*.mlmodel -tvm_u.* -tvm_t.* -# Mac OS X -.DS_Store - -# Jetbrain -.idea -.ipython -.jupyter -.nv -.pylint.d -.python_history -.pytest_cache -.local -cmake-build-debug - -# Visual Studio -.vs - -# Visual Studio Code -.vscode - -# tmp file -.nfs* - -# keys -*.pem -*.p12 -*.pfx -*.cer -*.crt -*.der - -# patch sentinel -patched.txt - -# Python type checking -.mypy_cache/ -.pyre/ - -# pipenv files -Pipfile -Pipfile.lock - -# conda package artifacts -conda/Dockerfile.cuda* -conda/pkg -.node_repl_history -# nix files -.envrc -*.nix diff --git a/third_party/tvm/apps/android_camera/app/src/main/jni/make/config.mk b/third_party/tvm/apps/android_camera/app/src/main/jni/make/config.mk new file mode 100644 index 0000000000000000000000000000000000000000..49e332665ad9cef68bbd0a4203c60365968944a3 --- /dev/null +++ b/third_party/tvm/apps/android_camera/app/src/main/jni/make/config.mk @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#------------------------------------------------------------------------------- +# Template configuration for compiling +# +# If you want to change the configuration, please use the following +# steps. Assume you are on the root directory. First copy the this +# file so that any local changes will be ignored by git +# +# cp make/config.mk . +# +# Next modify the according entries, and then compile by +# +# ./build.sh +# +#------------------------------------------------------------------------------- +APP_ABI = all + +APP_PLATFORM = android-24 + +# whether enable OpenCL during compile +USE_OPENCL = 0 + +# whether to enable Vulkan during compile +USE_VULKAN = 0 + +# whether to enable contrib sort functions during compile +USE_SORT = 1 + +ifeq ($(USE_VULKAN), 1) + # Statically linking vulkan requires API Level 24 or higher + APP_PLATFORM = android-24 +endif + +# the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc +ADD_C_INCLUDES = + +# the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so +ADD_LDLIBS = diff --git a/third_party/tvm/apps/android_deploy/app/src/main/jni/make/config.mk b/third_party/tvm/apps/android_deploy/app/src/main/jni/make/config.mk new file mode 100644 index 0000000000000000000000000000000000000000..bcd56e37896de940062fdb750b7aaaf58a2ec766 --- /dev/null +++ b/third_party/tvm/apps/android_deploy/app/src/main/jni/make/config.mk @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#------------------------------------------------------------------------------- +# Template configuration for compiling +# +# If you want to change the configuration, please use the following +# steps. Assume you are on the root directory. First copy the this +# file so that any local changes will be ignored by git +# +# cp make/config.mk . +# +# Next modify the according entries, and then compile by +# +# ./build.sh +# +#------------------------------------------------------------------------------- +APP_ABI = all + +APP_PLATFORM = android-17 + +# whether enable OpenCL during compile +USE_OPENCL = 0 + +# the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc +ADD_C_INCLUDES = + +# the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so +ADD_LDLIBS = diff --git a/third_party/tvm/apps/android_rpc/app/src/main/jni/make/config.mk b/third_party/tvm/apps/android_rpc/app/src/main/jni/make/config.mk new file mode 100644 index 0000000000000000000000000000000000000000..851430cd42a9cfa96d4c984122c53dcfaf9c0b57 --- /dev/null +++ b/third_party/tvm/apps/android_rpc/app/src/main/jni/make/config.mk @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#------------------------------------------------------------------------------- +# Template configuration for compiling +# +# If you want to change the configuration, please use the following +# steps. Assume you are on the root directory. First copy the this +# file so that any local changes will be ignored by git +# +# cp make/config.mk . +# +# Next modify the according entries, and then compile by +# +# ./build.sh +# +#------------------------------------------------------------------------------- +APP_ABI = all + +APP_PLATFORM = android-24 + +# whether enable OpenCL during compile +USE_OPENCL = 0 + +# whether to enable Vulkan during compile +USE_VULKAN = 0 + +# whether to enable contrib sort functions during compile +USE_SORT = 1 + +# whether to eanble contrib random functions during compile +USE_RANDOM = 1 + +ifeq ($(USE_VULKAN), 1) + # Statically linking vulkan requires API Level 24 or higher + APP_PLATFORM = android-24 +endif + +# the additional include headers you want to add, e.g., SDK_PATH/adrenosdk/Development/Inc +ADD_C_INCLUDES = + +# the additional link libs you want to add, e.g., ANDROID_LIB_PATH/libOpenCL.so +ADD_LDLIBS = diff --git a/third_party/tvm/apps/ios_rpc/tvmrpc/Assets.xcassets/AppIcon.appiconset/Contents.json b/third_party/tvm/apps/ios_rpc/tvmrpc/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..1d060ed28827ed6aca9565d946e6b5595c8978df --- /dev/null +++ b/third_party/tvm/apps/ios_rpc/tvmrpc/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,93 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/third_party/tvm/apps/sgx/Cargo.lock b/third_party/tvm/apps/sgx/Cargo.lock new file mode 100644 index 0000000000000000000000000000000000000000..b02ab331d8ce7d7c930002674716119506b60188 --- /dev/null +++ b/third_party/tvm/apps/sgx/Cargo.lock @@ -0,0 +1,853 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +[[package]] +name = "addr2line" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "gimli 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "adler" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "aho-corasick" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "ansi_term" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "arrayvec" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "hermit-abi 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "autocfg" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "backtrace" +version = "0.3.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "addr2line 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)", + "miniz_oxide 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "object 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "bindgen" +version = "0.51.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "cexpr 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "clang-sys 0.28.1 (registry+https://github.com/rust-lang/crates.io-index)", + "clap 2.33.1 (registry+https://github.com/rust-lang/crates.io-index)", + "env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", + "peeking_take_while 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)", + "rustc-hash 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "shlex 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "which 3.1.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "bitflags" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "cc" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "cexpr" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "cfg-if" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "clang-sys" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)", + "libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "clap" +version = "2.33.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", + "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", + "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-width 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", + "vec_map 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-channel 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-deque 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-epoch 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-queue 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-channel" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-deque" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "crossbeam-epoch 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)", + "scopeguard 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-queue" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "crossbeam-utils" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "either" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "env_logger" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", + "humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)", + "termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "failure" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "backtrace 0.3.50 (registry+https://github.com/rust-lang/crates.io-index)", + "failure_derive 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "failure_derive" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.34 (registry+https://github.com/rust-lang/crates.io-index)", + "synstructure 0.12.4 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "gimli" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "glob" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "goblin" +version = "0.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", + "plain 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "scroll 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "hermit-abi" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "humantime" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "itertools" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "itertools" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "itoa" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "lexical-core" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "ryu 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", + "static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "libc" +version = "0.2.72" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "libloading" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cc 1.0.58 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "log" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "matrixmultiply" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rawpointer 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "maybe-uninit" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "memchr" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "memoffset" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "miniz_oxide" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "adler 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "ndarray" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "itertools 0.7.11 (registry+https://github.com/rust-lang/crates.io-index)", + "matrixmultiply 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", + "num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "nom" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "nom" +version = "5.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-traits" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "hermit-abi 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "object" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "old-tvm-macros" +version = "0.1.1" +dependencies = [ + "goblin 0.0.24 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.34 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "proc-macro2" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "proc-macro2" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "unicode-xid 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "quote" +version = "0.6.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "quote" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rawpointer" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "regex" +version = "1.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "aho-corasick 0.7.13 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "regex-syntax 0.6.18 (registry+https://github.com/rust-lang/crates.io-index)", + "thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "regex-syntax" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "rustc-demangle" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "rustc_version" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "ryu" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "scroll" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "scroll_derive 0.9.5 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "scroll_derive" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 0.15.44 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "semver" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "semver-parser" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "serde" +version = "1.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "serde_derive" +version = "1.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.34 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "serde_json" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "itoa 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", + "ryu 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "sgx-demo" +version = "0.1.0" +dependencies = [ + "tvm-runtime 0.1.0", +] + +[[package]] +name = "shlex" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "syn" +version = "0.15.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "syn" +version = "1.0.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-xid 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "synstructure" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.34 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-xid 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "termcolor" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi-util 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "unicode-width 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "thread_local" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "tvm-common" +version = "0.1.0" +dependencies = [ + "bindgen 0.51.1 (registry+https://github.com/rust-lang/crates.io-index)", + "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", + "ndarray 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "tvm-runtime" +version = "0.1.0" +dependencies = [ + "crossbeam 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", + "failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", + "itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)", + "ndarray 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)", + "nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)", + "old-tvm-macros 0.1.1", + "serde 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_derive 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_json 1.0.56 (registry+https://github.com/rust-lang/crates.io-index)", + "tvm-common 0.1.0", +] + +[[package]] +name = "unicode-width" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "unicode-xid" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "unicode-xid" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "version_check" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "version_check" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "which" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[metadata] +"checksum addr2line 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1b6a2d3371669ab3ca9797670853d61402b03d0b4b9ebf33d677dfa720203072" +"checksum adler 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ee2a4ec343196209d6594e19543ae87a39f96d5534d7174822a3ad825dd6ed7e" +"checksum aho-corasick 0.7.13 (registry+https://github.com/rust-lang/crates.io-index)" = "043164d8ba5c4c3035fec9bbee8647c0261d788f3474306f93bb65901cae0e86" +"checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" +"checksum arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8" +"checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +"checksum autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d" +"checksum backtrace 0.3.50 (registry+https://github.com/rust-lang/crates.io-index)" = "46254cf2fdcdf1badb5934448c1bcbe046a56537b3987d96c51a7afc5d03f293" +"checksum bindgen 0.51.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ebd71393f1ec0509b553aa012b9b58e81dadbdff7130bd3b8cba576e69b32f75" +"checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +"checksum cc 1.0.58 (registry+https://github.com/rust-lang/crates.io-index)" = "f9a06fb2e53271d7c279ec1efea6ab691c35a2ae67ec0d91d7acec0caf13b518" +"checksum cexpr 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "fce5b5fb86b0c57c20c834c1b412fd09c77c8a59b9473f86272709e78874cd1d" +"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +"checksum clang-sys 0.28.1 (registry+https://github.com/rust-lang/crates.io-index)" = "81de550971c976f176130da4b2978d3b524eaa0fd9ac31f3ceb5ae1231fb4853" +"checksum clap 2.33.1 (registry+https://github.com/rust-lang/crates.io-index)" = "bdfa80d47f954d53a35a64987ca1422f495b8d6483c0fe9f7117b36c2a792129" +"checksum crossbeam 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "69323bff1fb41c635347b8ead484a5ca6c3f11914d784170b158d8449ab07f8e" +"checksum crossbeam-channel 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "cced8691919c02aac3cb0a1bc2e9b73d89e832bf9a06fc579d4e71b68a2da061" +"checksum crossbeam-deque 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "9f02af974daeee82218205558e51ec8768b48cf524bd01d550abe5573a608285" +"checksum crossbeam-epoch 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace" +"checksum crossbeam-queue 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "774ba60a54c213d409d5353bda12d49cd68d14e45036a285234c8d6f91f92570" +"checksum crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8" +"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3" +"checksum failure 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "d32e9bd16cc02eae7db7ef620b392808b89f6a5e16bb3497d159c6b92a0f4f86" +"checksum failure_derive 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4" +"checksum gimli 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)" = "aaf91faf136cb47367fa430cd46e37a788775e7fa104f8b4bcb3861dc389b724" +"checksum glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" +"checksum goblin 0.0.24 (registry+https://github.com/rust-lang/crates.io-index)" = "e3fa261d919c1ae9d1e4533c4a2f99e10938603c4208d56c05bec7a872b661b0" +"checksum hermit-abi 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "3deed196b6e7f9e44a2ae8d94225d80302d81208b1bb673fd21fe634645c85a9" +"checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" +"checksum itertools 0.7.11 (registry+https://github.com/rust-lang/crates.io-index)" = "0d47946d458e94a1b7bcabbf6521ea7c037062c81f534615abcad76e84d4970d" +"checksum itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" +"checksum itoa 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6" +"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +"checksum lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)" = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616" +"checksum libc 0.2.72 (registry+https://github.com/rust-lang/crates.io-index)" = "a9f8082297d534141b30c8d39e9b1773713ab50fdbe4ff30f750d063b3bfd701" +"checksum libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f2b111a074963af1d37a139918ac6d49ad1d0d5e47f72fd55388619691a7d753" +"checksum log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)" = "14b6052be84e6b71ab17edffc2eeabf5c2c3ae1fdb464aae35ac50c67a44e1f7" +"checksum matrixmultiply 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "dcad67dcec2d58ff56f6292582377e6921afdf3bfbd533e26fb8900ae575e002" +"checksum maybe-uninit 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00" +"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" +"checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f" +"checksum miniz_oxide 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "be0f75932c1f6cfae3c04000e40114adf955636e19040f9c0a2c380702aa1c7f" +"checksum ndarray 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7cf380a8af901ad627594013a3bbac903ae0a6f94e176e47e46b5bbc1877b928" +"checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6" +"checksum nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af" +"checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +"checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611" +"checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +"checksum object 0.20.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1ab52be62400ca80aa00285d25253d7f7c437b7375c4de678f5405d3afe82ca5" +"checksum peeking_take_while 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +"checksum plain 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +"checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759" +"checksum proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa" +"checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +"checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1" +"checksum quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37" +"checksum rawpointer 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ebac11a9d2e11f2af219b8b8d833b76b1ea0e054aa0e8d8e9e4cbde353bdf019" +"checksum regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "9c3780fcf44b193bc4d09f36d2a3c87b251da4a046c87795a0d35f4f927ad8e6" +"checksum regex-syntax 0.6.18 (registry+https://github.com/rust-lang/crates.io-index)" = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" +"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783" +"checksum rustc-hash 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" +"checksum ryu 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +"checksum scopeguard 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +"checksum scroll 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "2f84d114ef17fd144153d608fba7c446b0145d038985e7a8cc5d08bb0ce20383" +"checksum scroll_derive 0.9.5 (registry+https://github.com/rust-lang/crates.io-index)" = "8f1aa96c45e7f5a91cb7fabe7b279f02fea7126239fc40b732316e8b6a2d0fcb" +"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" +"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" +"checksum serde 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)" = "5317f7588f0a5078ee60ef675ef96735a1442132dc645eb1d12c018620ed8cd3" +"checksum serde_derive 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)" = "2a0be94b04690fbaed37cddffc5c134bf537c8e3329d53e982fe04c374978f8e" +"checksum serde_json 1.0.56 (registry+https://github.com/rust-lang/crates.io-index)" = "3433e879a558dde8b5e8feb2a04899cf34fdde1fafb894687e52105fc1162ac3" +"checksum shlex 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2" +"checksum static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" +"checksum syn 0.15.44 (registry+https://github.com/rust-lang/crates.io-index)" = "9ca4b3b69a77cbe1ffc9e198781b7acb0c7365a883670e8f1c1bc66fba79a5c5" +"checksum syn 1.0.34 (registry+https://github.com/rust-lang/crates.io-index)" = "936cae2873c940d92e697597c5eee105fb570cd5689c695806f672883653349b" +"checksum synstructure 0.12.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701" +"checksum termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb6bfa289a4d7c5766392812c0a1f4c1ba45afa1ad47803c11e1f407d846d75f" +"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +"checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" +"checksum unicode-width 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3" +"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" +"checksum unicode-xid 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" +"checksum vec_map 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" +"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd" +"checksum version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed" +"checksum which 3.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d011071ae14a2f6671d0b74080ae0cd8ebf3a6f8c9589a2cd45f23126fe29724" +"checksum winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +"checksum winapi-util 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +"checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/third_party/tvm/nnvm/make/config.mk b/third_party/tvm/nnvm/make/config.mk new file mode 100644 index 0000000000000000000000000000000000000000..4a210ea484bc2ad5f46e57e3e5ceaad9102e7879 --- /dev/null +++ b/third_party/tvm/nnvm/make/config.mk @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#------------------------------------------------------------------------------- +# Template configuration for compiling nnvm +# +# If you want to change the configuration, please use the following +# steps. Assume you are on the root directory of nnvm. First copy the this +# file so that any local changes will be ignored by git +# +# $ cp make/config.mk . +# +# Next modify the according entries, and then compile by +# +# $ make +# +# or build in parallel with 8 threads +# +# $ make -j8 +#------------------------------------------------------------------------------- + +#--------------------- +# choice of compiler +#-------------------- + +export NVCC = nvcc + +# choice of archiver +export AR = ar + +# the additional link flags you want to add +ADD_LDFLAGS= + +# the additional compile flags you want to add +ADD_CFLAGS= + +# path to dmlc-core module +#DMLC_CORE_PATH= + +#---------------------------- +# plugins +#---------------------------- + +# whether to use fusion integration. This requires installing cuda. +# ifndef CUDA_PATH +# CUDA_PATH = /usr/local/cuda +# endif +# NNVM_FUSION_PATH = plugin/nnvm-fusion +# NNVM_PLUGINS += $(NNVM_FUSION_PATH)/nnvm-fusion.mk diff --git a/third_party/tvm/tests/python/contrib/test_arm_compute_lib/test_config.json b/third_party/tvm/tests/python/contrib/test_arm_compute_lib/test_config.json new file mode 100644 index 0000000000000000000000000000000000000000..c8168ae8d25530543c8cb918198a0f1046e82961 --- /dev/null +++ b/third_party/tvm/tests/python/contrib/test_arm_compute_lib/test_config.json @@ -0,0 +1,8 @@ +{ + "connection_type": "local", + "host": "localhost", + "port": 9090, + "target": "llvm -mtriple=aarch64-linux-gnu -mattr=+neon", + "device_key": "", + "cross_compile": "" +} diff --git a/third_party/tvm/tutorials/auto_scheduler/ci_logs/conv2d.json b/third_party/tvm/tutorials/auto_scheduler/ci_logs/conv2d.json new file mode 100644 index 0000000000000000000000000000000000000000..c748920d14db31b31fd6165c7f83daed43cfa66b --- /dev/null +++ b/third_party/tvm/tutorials/auto_scheduler/ci_logs/conv2d.json @@ -0,0 +1,2 @@ +# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI. +{"i": [["[\"conv2d_layer\", 1, 7, 7, 512, 512, 3, 3, [1, 1], [1, 1]]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32"], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 512, [1, 64, 2, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 7, [1, 1, 7, 1], 1], ["SP", 3, 20, 512, [4, 2], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 48, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 504, [4], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[0.000429498], 0, 1.59126, 1603259147], "v": "v0.2"} diff --git a/third_party/tvm/tutorials/auto_scheduler/ci_logs/matmul.json b/third_party/tvm/tutorials/auto_scheduler/ci_logs/matmul.json new file mode 100644 index 0000000000000000000000000000000000000000..827cfc9a6dbb06863e729489ca541205c4e7c1b8 --- /dev/null +++ b/third_party/tvm/tutorials/auto_scheduler/ci_logs/matmul.json @@ -0,0 +1,2 @@ +# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI. +{"i": [["[\"matmul_add\", 128, 128, 128, \"float32\"]", "llvm -keys=cpu"], [[], [["SP", 2, 0, 128, [4, 2, 4], 1], ["SP", 2, 4, 128, [1, 32, 2], 1], ["SP", 2, 8, 128, [2], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FSP", 4, 0, 0, 1], ["FSP", 4, 2, 1, 1], ["RE", 4, [0, 2, 1, 3]], ["CA", 2, 4, 1], ["FU", 4, [0, 1]], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$0"], ["AN", 2, 9, 2]]]], "r": [[5.80388e-05], 0, 0.299169, 1603402396], "v": "v0.2"} diff --git a/third_party/tvm/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json b/third_party/tvm/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json new file mode 100644 index 0000000000000000000000000000000000000000..41b6c0e554ed2118eda9abaaa9761c0f82301175 --- /dev/null +++ b/third_party/tvm/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json @@ -0,0 +1,26 @@ +# Provide valid schedules for resnet-18. +# This is used to run the tutorial on the documentation web server. +{"i": [["[\"b32ed43fb351136894c322ee49097a1a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 4, 1, 1000, [50], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$0"], ["PR", 3, 0, "auto_unroll_max_step$1024"]]]], "r": [[4.54041e-06], 0, 1.27943, 1605490839], "v": "v0.3"} +{"i": [["[\"d09dc1a6bb90d59c91b68989ad3492ff\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 4], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 6], ["FU", 6, [0, 1]], ["AN", 6, 0, 5], ["FU", 6, [1, 2]], ["AN", 6, 1, 4], ["FU", 6, [2, 3]], ["AN", 6, 2, 6], ["FU", 3, [0, 1]], ["SP", 3, 0, 4, [4], 1], ["AN", 3, 1, 2], ["FFSP", 3, 0, [1, 0], 1, 1], ["AN", 3, 1, 6], ["FU", 1, [0, 1]], ["SP", 1, 0, 4, [2], 1], ["AN", 1, 1, 2], ["FFSP", 1, 0, [1, 0], 1, 1], ["AN", 1, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[1.03431e-05], 0, 2.09134, 1605490924], "v": "v0.3"} +{"i": [["[\"7de313da0ca29a8c63f647791692430d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [8], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$16"]]]], "r": [[5.51259e-06], 0, 1.30207, 1605491060], "v": "v0.3"} +{"i": [["[\"944921d3fd999ba7aa9ffe5a592a9241\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [56], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$512"]]]], "r": [[2.24305e-05], 0, 1.60311, 1605493879], "v": "v0.3"} +{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [2, 1, 1, 8], 1], ["SP", 3, 10, 112, [1, 8, 1, 1], 1], ["SP", 3, 15, 64, [2, 16, 2, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [1, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 294, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 441, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[7.63468e-05], 0, 2.59544, 1605493932], "v": "v0.3"} +{"i": [["[\"bf78a7bf0209980f72953637dfd14a6f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [7, 4, 2, 1], 1], ["SP", 3, 10, 56, [1, 2, 2, 1], 1], ["SP", 3, 15, 64, [2, 16, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [8, 4], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 32, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 128, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[1.26775e-05], 0, 1.94247, 1605494103], "v": "v0.3"} +{"i": [["[\"6630936c26852f2b89dbfa2ff37fbb9c\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 2], 1], ["SP", 3, 10, 28, [1, 1, 2, 1], 1], ["SP", 3, 15, 128, [1, 16, 1, 8], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 128, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 144, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[1.13004e-05], 0, 1.86312, 1605494224], "v": "v0.3"} +{"i": [["[\"ba5f918733ccbbd4a1d7fd3724665a2f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 2, 1], 1], ["SP", 3, 10, 14, [1, 14, 1, 1], 1], ["SP", 3, 15, 256, [1, 8, 4, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 64, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 48, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[1.29425e-05], 0, 1.70493, 1605494303], "v": "v0.3"} +{"i": [["[\"21ad409d72953de188314010134e3acd\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [2, 16, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [2, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 16, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 16, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[2.04683e-05], 0, 1.80217, 1605494406], "v": "v0.3"} +{"i": [["[\"022ebb6b7c55c5ed030421380ec83a04\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 1, 7], 1], ["SP", 3, 10, 28, [1, 4, 1, 1], 1], ["SP", 3, 15, 128, [1, 32, 2, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 64, [1, 4], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 72, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 348, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[4.93528e-05], 0, 1.74125, 1605498773], "v": "v0.3"} +{"i": [["[\"ac6920940de3797cc3f9f9c260675e5d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [8], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 2, 1, 1], 1], ["SP", 6, 10, 16, [2, 1, 8, 1], 1], ["SP", 6, 15, 512, [1, 32, 2, 1], 1], ["SP", 6, 20, 512, [8, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [16], 1], ["SP", 4, 4, 512, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 25088, [49], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 256, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000129562], 0, 3.40317, 1605500470], "v": "v0.3"} +{"i": [["[\"1f6cd3637ec856bf5cf5010a623eed05\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 1, 7], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 16, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 256, [4, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 288, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 1440, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[7.57476e-05], 0, 2.59558, 1605501054], "v": "v0.3"} +{"i": [["[\"c5ee3e05edd9754492d0763aa41fd025\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 2, 2, 1], 1], ["SP", 6, 10, 196, [4, 1, 1, 7], 1], ["SP", 6, 15, 128, [2, 32, 1, 1], 1], ["SP", 6, 20, 128, [2, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [14], 1], ["SP", 4, 4, 128, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 100352, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [49], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 8, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 56, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[6.77244e-05], 0, 2.67201, 1605501438], "v": "v0.3"} +{"i": [["[\"c035cc8b0568a8e054d06bd7f4950550\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 128, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 196, [1, 7, 7, 1], 1], ["SP", 6, 15, 128, [8, 16, 1, 1], 1], ["SP", 6, 20, 128, [1, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [14], 1], ["SP", 4, 4, 128, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 100352, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [16], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 8, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 8, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[6.23875e-05], 0, 1.93274, 1605501606], "v": "v0.3"} +{"i": [["[\"f2e3c09a00e7d0a9897f70497e089f1e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 2, 1], 1], ["SP", 6, 5, 6, [1, 1, 1, 1], 1], ["SP", 6, 10, 196, [1, 7, 1, 4], 1], ["SP", 6, 15, 64, [2, 16, 1, 1], 1], ["SP", 6, 20, 64, [1, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [28], 1], ["SP", 4, 4, 64, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 200704, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$512"]]]], "r": [[6.65448e-05], 0, 2.94376, 1605501803], "v": "v0.3"} +{"i": [["[\"81aae4b8e2c076a4014d403e8a2c70a1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 2], 1], ["SP", 3, 10, 14, [2, 7, 1, 1], 1], ["SP", 3, 15, 256, [1, 32, 2, 1], 1], ["SP", 3, 20, 3, [1, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [2, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 192, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 240, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[6.31245e-05], 0, 1.9322, 1605501903], "v": "v0.3"} +{"i": [["[\"7e83a2ee5cd5d50282ed19310700046a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [1, 2, 4, 2], 1], ["SP", 6, 15, 512, [2, 32, 1, 1], 1], ["SP", 6, 20, 512, [16, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [8], 1], ["SP", 4, 4, 512, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 25088, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 128, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[0.000143154], 0, 2.20107, 1605502293], "v": "v0.3"} +{"i": [["[\"424ba83160af31badc0b098136e1a3b0\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 49, 1, 1], 1], ["SP", 6, 15, 256, [8, 2, 2, 2], 1], ["SP", 6, 20, 256, [2, 16], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [1], 1], ["SP", 4, 4, 256, [1], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 50176, [64], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 128, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[0.000115017], 0, 3.89122, 1605502608], "v": "v0.3"} +{"i": [["[\"c7a6b56bdc04b94c829fb2ef9874019e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [4], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [2, 1, 1, 1], 1], ["SP", 6, 10, 196, [1, 1, 2, 14], 1], ["SP", 6, 15, 128, [1, 32, 1, 2], 1], ["SP", 6, 20, 128, [1, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [1], 1], ["SP", 4, 4, 128, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 100352, [64], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 224, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[7.20936e-05], 0, 3.36582, 1605502968], "v": "v0.3"} +{"i": [["[\"0141ffc4fbabc10cc5a94c954419055b\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 49, 1, 1], 1], ["SP", 6, 15, 256, [8, 1, 2, 2], 1], ["SP", 6, 20, 256, [1, 32], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [1], 1], ["SP", 4, 4, 256, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 50176, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 128, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 16, [2], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[0.000122349], 0, 4.2774, 1605503135], "v": "v0.3"} +{"i": [["[\"a169cd0053d3a7ca82998fcb62e42c58\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 7, 1, 7], 1], ["SP", 6, 15, 256, [8, 4, 1, 1], 1], ["SP", 6, 20, 256, [1, 16], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [7], 1], ["SP", 4, 4, 256, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 50176, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 256, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[7.9277e-05], 0, 3.07064, 1605503350], "v": "v0.3"} +{"i": [["[\"fa26946d7ac51126bfa859cb183f9ca1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 1], 1], ["SP", 6, 5, 6, [1, 2, 1, 1], 1], ["SP", 6, 10, 196, [7, 7, 1, 4], 1], ["SP", 6, 15, 64, [1, 8, 4, 1], 1], ["SP", 6, 20, 64, [4, 2], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [28], 1], ["SP", 4, 4, 64, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 200704, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 16, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[7.64176e-05], 0, 5.45091, 1605503568], "v": "v0.3"} +{"i": [["[\"de0df0893e01892cfe69f7bc2c24111f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 64, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 1, 1], 1], ["SP", 6, 5, 6, [1, 1, 1, 1], 1], ["SP", 6, 10, 196, [14, 7, 1, 2], 1], ["SP", 6, 15, 64, [1, 16, 1, 2], 1], ["SP", 6, 20, 64, [1, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [2], 1], ["SP", 4, 4, 64, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 200704, [64], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 4, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$0"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[7.60496e-05], 0, 3.00771, 1605503805], "v": "v0.3"} +{"i": [["[\"8d5a93959138dc7b2ee1f1b3219dfa14\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [1, 1, 4, 4], 1], ["SP", 6, 15, 512, [1, 64, 1, 1], 1], ["SP", 6, 20, 512, [1, 32], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 13, 3], ["FSP", 7, 4, 14, 3], ["FSP", 7, 8, 15, 3], ["FSP", 7, 12, 16, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [8], 1], ["SP", 4, 4, 512, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 19, [0, 1, 2, 3]], ["SP", 19, 0, 25088, [32], 1], ["AN", 19, 0, 5], ["AN", 19, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [16], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [16, 15, 14, 13], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [16, 15, 14, 13], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [16], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$0"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[0.000135079], 0, 2.40957, 1605504233], "v": "v0.3"} diff --git a/third_party/tvm/web/.eslintrc.json b/third_party/tvm/web/.eslintrc.json new file mode 100644 index 0000000000000000000000000000000000000000..0724c440041b3247d15d402bbef9e10bad445e5e --- /dev/null +++ b/third_party/tvm/web/.eslintrc.json @@ -0,0 +1,34 @@ +{ + "env": { + "browser": true, + "es6": true + }, + "extends": ["eslint:recommended"], + "root": true, + "parser": "@typescript-eslint/parser", + "parserOptions": { + "ecmaVersion": 2018, + "sourceType": "module" + }, + "overrides": [ + { + "files": ["src/**.ts", "src/**.tsx"], + "plugins": ["@typescript-eslint"], + "extends": [ + "plugin:@typescript-eslint/eslint-recommended", + "plugin:@typescript-eslint/recommended" + ], + "rules": { + "require-jsdoc": 0, + "@typescript-eslint/no-explicit-any": 0, + "@typescript-eslint/no-empty-function": 0 + } + }, + { + "files": ["tests/node/*.js", "apps/node/*.js"], + "env": { + "node": true + } + } + ] +} diff --git a/third_party/tvm/web/package.json b/third_party/tvm/web/package.json new file mode 100644 index 0000000000000000000000000000000000000000..dafccb0a864856ecb8be2bc6c9cd2e3b3605d818 --- /dev/null +++ b/third_party/tvm/web/package.json @@ -0,0 +1,32 @@ +{ + "name": "tvmjs", + "displayName": "TVM Wasm JS runtime", + "license": "Apache-2.0", + "version": "0.8.0-dev0", + "scripts": { + "prepwasm": "make && python3 tests/python/prepare_test_libs.py", + "build": "tsc -b && make rmtypedep", + "lint": "eslint -c .eslintrc.json .", + "typedoc": "typedoc .", + "test": "jest", + "bundle": "npm run build && rollup -c rollup.config.js", + "example": "npm run bundle && node apps/node/example.js", + "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js", + "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js" + }, + "devDependencies": { + "@rollup/plugin-commonjs": "^11.1.0", + "@rollup/plugin-node-resolve": "^7.1.3", + "@types/node": "^12.12.37", + "@typescript-eslint/eslint-plugin": "^2.29.0", + "@typescript-eslint/parser": "^2.29.0", + "@webgpu/types": "^0.0.31", + "eslint": "^6.8.0", + "jest": "^26.0.1", + "rollup": "^2.7.6", + "rollup-plugin-typescript2": "^0.27.0", + "typedoc": "^0.17.6", + "typescript": "^3.8.3", + "ws": "^7.2.5" + } +} diff --git a/third_party/tvm/web/tsconfig.json b/third_party/tvm/web/tsconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..6aec44858a7ae203c0268aef09dda9394668d372 --- /dev/null +++ b/third_party/tvm/web/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "es6", + "outDir": "dist", + "rootDir": "src", + "declaration": true, + "sourceMap": true, + "strict": true + }, + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/third_party/tvm/web/typedoc.json b/third_party/tvm/web/typedoc.json new file mode 100644 index 0000000000000000000000000000000000000000..65631ea5efa8bb3d0b007f2ade27685ea14fb799 --- /dev/null +++ b/third_party/tvm/web/typedoc.json @@ -0,0 +1,11 @@ +{ + "out": "dist/docs", + "readme": "none", + "mode": "file", + "excludeNotExported": true, + "excludePrivate": true, + "listInvalidSymbolLinks": true, + "module": "umd", + "includes": ["src"], + "exclude": ["node_modules"] +} diff --git a/third_party/xbyak/.gitignore b/third_party/xbyak/.gitignore deleted file mode 100644 index 24b0b1de5b56d939246534aa312e88b145f82160..0000000000000000000000000000000000000000 --- a/third_party/xbyak/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/build* # cmake