"examples/sampling/vscode:/vscode.git/clone" did not exist on "346197c47722993cf8cd9f41891d1457ef82decc"
Commit c454d419 authored by lisj's avatar lisj
Browse files

删除子模块的gitignore

parent 3359c1f1
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Kunal Banerjee, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_elementwise.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, int start_thread, int tid, int nthreads)
{
  /* Zero-fills this thread's statically assigned share of the size-element buffer. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  /* clamp the [begin, end) range of this thread to the valid index space */
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    src[idx] = (LIBXSMM_DNN_ELTWISE_FTYPE)0;
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_add(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads)
{
  /* c[i] = a[i] + b[i] over this thread's statically assigned share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    c[idx] = a[idx] + b[idx];
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads)
{
  /* c[i] = a[i] * b[i] (Hadamard product) over this thread's share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    c[idx] = a[idx] * b[idx];
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* dst[i] = 1/(1 + exp(-src[i])) over this thread's share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    /* exponential is evaluated in double precision, as elsewhere in this file */
    const LIBXSMM_DNN_ELTWISE_FTYPE e = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[idx]);
    dst[idx] = 1 / (1 + e);
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* dst[i] = tanh(src[i]) over this thread's share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    dst[idx] = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double)src[idx]);
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* dst[i] = max(src[i], 0) over this thread's share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    const LIBXSMM_DNN_ELTWISE_FTYPE v = src[idx];
    dst[idx] = (v > 0.0f) ? v : 0.0f;
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* Sigmoid derivative: dst[i] = s*(1-s) with s = sigmoid(src[i]), over this thread's share. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    const LIBXSMM_DNN_ELTWISE_FTYPE e = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[idx]);
    const LIBXSMM_DNN_ELTWISE_FTYPE s = 1 / (1 + e);
    dst[idx] = (1 - s) * s;
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* Tanh derivative: dst[i] = 1 - tanh(src[i])^2, over this thread's share. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    const LIBXSMM_DNN_ELTWISE_FTYPE t = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double)src[idx]);
    dst[idx] = 1 - (t * t);
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* ReLU derivative: dst[i] = 1 if src[i] > 0, else 0, over this thread's share. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    dst[idx] = (LIBXSMM_DNN_ELTWISE_FTYPE)(src[idx] > 0.0f ? 1.0f : 0.0f);
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_transpose(libxsmm_blasint rows, libxsmm_blasint cols, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* Transposes src (rows x cols) into dst (cols x rows); the flattened
   * element index space is split statically across the participating threads. */
  const int ltid = tid - start_thread;
  /* number of tasks that could be run in parallel */
  const libxsmm_blasint size = rows * cols;
  /* compute chunk size (ceiling of size/nthreads) */
  const libxsmm_blasint chunksize = (size % nthreads == 0) ? (size / nthreads) : (size / nthreads) + 1;
  /* compute thr_begin and thr_end, clamped to [0, size] */
  const libxsmm_blasint thr_begin = (ltid * chunksize < size) ? (ltid * chunksize) : size;
  const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, size);
  /* 2D views: src2D is rows x cols, dst2D is cols x rows (last dim is contiguous) */
  LIBXSMM_VLA_DECL(2, LIBXSMM_DNN_ELTWISE_FTYPE, src2D, src, cols);
  LIBXSMM_VLA_DECL(2, LIBXSMM_DNN_ELTWISE_FTYPE, dst2D, dst, rows);
  libxsmm_blasint job;
  for (job = thr_begin; job < thr_end; ++job) {
    /* decode the flat job index into (row i, column j) of src */
    const libxsmm_blasint i = job / cols;
    const libxsmm_blasint j = job % cols;
    LIBXSMM_VLA_ACCESS(2, dst2D, j, i, rows) = LIBXSMM_VLA_ACCESS(2, src2D, i, j, cols);
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* Copies src into dst over this thread's statically assigned share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    dst[idx] = src[idx];
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* dst[i] = 1 - src[i] over this thread's statically assigned share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    dst[idx] = 1 - src[idx];
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* dst[i] = 1 - src[i]^2 over this thread's statically assigned share of the buffers. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    const LIBXSMM_DNN_ELTWISE_FTYPE v = src[idx];
    dst[idx] = 1 - (v * v);
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* Elementwise negation (despite the name): dst[i] = -src[i], over this thread's share. */
  const int my_id = tid - start_thread;
  /* per-thread chunk: ceiling of size/nthreads */
  const libxsmm_blasint chunk = (size / nthreads) + ((size % nthreads == 0) ? 0 : 1);
  const libxsmm_blasint begin = LIBXSMM_MIN(my_id * chunk, size);
  const libxsmm_blasint end = LIBXSMM_MIN(begin + chunk, size);
  libxsmm_blasint idx;
  for (idx = begin; idx < end; idx++) {
    dst[idx] = -src[idx];
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_1D_2D(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint bm, libxsmm_blasint bn, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads)
{
  /* Broadcasts the length-m vector src across n columns into dst, which is
   * viewed as a blocked 4D layout [n/bn][m/bm][bn][bm]; work is split across
   * threads along the m dimension.
   * NOTE(review): the indexing assumes m %% bm == 0 and n %% bn == 0 -- confirm
   * with callers, as a remainder block would index out of bounds. */
  const int ltid = tid - start_thread;
  /* compute chunk size (ceiling of m/nthreads) */
  const libxsmm_blasint chunksize = (m % nthreads == 0) ? (m / nthreads) : (m / nthreads) + 1;
  /* compute thr_begin and thr_end, clamped to [0, m] */
  const libxsmm_blasint thr_begin = (ltid * chunksize < m) ? (ltid * chunksize) : m;
  const libxsmm_blasint thr_end = LIBXSMM_MIN(ltid * chunksize + chunksize, m);
  libxsmm_blasint i, j;
  LIBXSMM_VLA_DECL(4, LIBXSMM_DNN_ELTWISE_FTYPE, real_dst, (LIBXSMM_DNN_ELTWISE_FTYPE*)dst, m/bm, bn, bm);
  for (i = thr_begin; i < thr_end; i++) {
    /* split row index i into block index mb and intra-block offset ibm */
    const libxsmm_blasint mb = i/bm;
    const libxsmm_blasint ibm = i%bm;
    for (j = 0; j < n; j++) {
      /* split column index j into block index nb and intra-block offset ibn */
      const libxsmm_blasint nb = j/bn;
      const libxsmm_blasint ibn = j%bn;
      LIBXSMM_VLA_ACCESS(4, real_dst, nb, mb, ibn, ibm, m/bm, bn, bm) = src[i];
    }
  }
}
/* #define LSTM_TIMING */
#if defined(LSTM_TIMING)
extern double Gbl_t_input_total, Gbl_t_recur_total, Gbl_t_eltwise_total, Gbl_t_nonlin_total;
extern unsigned long long Gbl_t_input, Gbl_t_recur, Gbl_t_eltwise, Gbl_t_nonlin;
extern double Gbl_duration_input, Gbl_duration_recur, Gbl_duration_eltwise, Gbl_duration_nonlin;
#endif
LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst) {
  /* Clears an m x n tile of srcdst (leading dimension ld) to zero. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    LIBXSMM_DNN_ELTWISE_FTYPE *const column = &srcdst[col*ld];
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      column[row] = (LIBXSMM_DNN_ELTWISE_FTYPE)0;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Copies an m x n tile from src to dst; both share the leading dimension ld. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const LIBXSMM_DNN_ELTWISE_FTYPE *const in = &src[col*ld];
    LIBXSMM_DNN_ELTWISE_FTYPE *const out = &dst[col*ld];
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      out[row] = in[row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise sum of two m x n tiles: dst = src0 + src1 (leading dimension ld). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] = src0[off+row] + src1[off+row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sub_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise difference of two m x n tiles: dst = src0 - src1 (leading dimension ld). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] = src0[off+row] - src1[off+row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Hadamard product of two m x n tiles: dst = src0 .* src1 (leading dimension ld). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] = src0[off+row] * src1[off+row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst) {
  /* In-place Hadamard product: srcdst .*= src0 over an m x n tile (leading dimension ld). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      srcdst[off+row] *= src0[off+row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_fma_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise multiply-accumulate: dst += src0 .* src1 over an m x n tile (ld = leading dim). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] += src0[off+row] * src1[off+row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv) {
  /* Adds the length-m column vector colv to every column of the m x n tile srcdst. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    LIBXSMM_DNN_ELTWISE_FTYPE *const column = &srcdst[col*ld];
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      column[row] += colv[row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv) {
  /* Broadcasts the length-m column vector colv into every column of the m x n tile srcdst. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    LIBXSMM_DNN_ELTWISE_FTYPE *const column = &srcdst[col*ld];
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      column[row] = colv[row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv) {
  /* Converts the length-m BF16 column vector colv to FP32 and broadcasts it
   * into every column of the m x n tile srcdst (leading dimension ld). */
  libxsmm_blasint i, j;
  /* BF16->FP32 via union: the BF16 bits become the upper 16 bits of the FP32
   * word, the lower 16 bits stay zero (i[1] is the high half-word on the
   * little-endian targets this union scheme is written for). */
  libxsmm_bfloat16_hp t;
  t.i[0] = 0;
  for ( j = 0; j < n; ++j ) {
    for ( i = 0; i < m; ++i ) {
      t.i[1] = colv[i];
      srcdst[(j*ld)+i] = t.f;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias) {
  /* Broadcasts colv[i] + const_bias into every column of the m x n tile srcdst. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    LIBXSMM_DNN_ELTWISE_FTYPE *const column = &srcdst[col*ld];
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      column[row] = colv[row] + const_bias;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias) {
  /* Converts the length-m BF16 column vector colv to FP32, adds const_bias,
   * and broadcasts the result into every column of the m x n tile srcdst. */
  libxsmm_blasint i, j;
  /* BF16->FP32 via union: BF16 bits fill the upper half-word of the FP32
   * value, lower half-word stays zero (little-endian layout assumed). */
  libxsmm_bfloat16_hp t;
  t.i[0] = 0;
  for ( j = 0; j < n; ++j ) {
    for ( i = 0; i < m; ++i ) {
      t.i[1] = colv[i];
      srcdst[(j*ld)+i] = t.f + const_bias;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise sigmoid over an m x n tile: dst = 1/(1+exp(-src)), leading dimension ld. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE e = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[off+row]);
      dst[off+row] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 / ((LIBXSMM_DNN_ELTWISE_FTYPE)1 + e);
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise tanh over an m x n tile (leading dimension ld). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double) src[off+row]);
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise ReLU over an m x n tile: dst = max(src, 0), leading dimension ld. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE v = src[off+row];
      dst[off+row] = (v < 0) ? (LIBXSMM_DNN_ELTWISE_FTYPE)0 : v;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Sigmoid derivative over an m x n tile: dst = s*(1-s) with s = sigmoid(src). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE e = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[off+row]);
      const LIBXSMM_DNN_ELTWISE_FTYPE s = (LIBXSMM_DNN_ELTWISE_FTYPE)1 / ((LIBXSMM_DNN_ELTWISE_FTYPE)1 + e);
      dst[off+row] = ((LIBXSMM_DNN_ELTWISE_FTYPE)1 - s) * s;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Tanh derivative over an m x n tile: dst = 1 - tanh(src)^2, leading dimension ld. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE t = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double) src[off+row]);
      dst[off+row] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 - (t * t);
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* ReLU derivative over an m x n tile: dst = 0 where src < 0, else 1. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] = (src[off+row] < 0) ? (LIBXSMM_DNN_ELTWISE_FTYPE)0 : (LIBXSMM_DNN_ELTWISE_FTYPE)1;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Fused backward step: dst .*= sigmoid'(src) = s*(1-s), over an m x n tile. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE e = (LIBXSMM_DNN_ELTWISE_FTYPE)exp((double) -src[off+row]);
      const LIBXSMM_DNN_ELTWISE_FTYPE s = (LIBXSMM_DNN_ELTWISE_FTYPE)1 / ((LIBXSMM_DNN_ELTWISE_FTYPE)1 + e);
      dst[off+row] *= ((LIBXSMM_DNN_ELTWISE_FTYPE)1 - s) * s;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Fused backward step: dst .*= tanh'(src) = 1 - tanh(src)^2, over an m x n tile. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE t = (LIBXSMM_DNN_ELTWISE_FTYPE)tanh((double) src[off+row]);
      dst[off+row] *= (LIBXSMM_DNN_ELTWISE_FTYPE)1 - (t * t);
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Fused backward step: dst .*= relu'(src), i.e. zero dst where src < 0. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] *= (src[off+row] < 0) ? (LIBXSMM_DNN_ELTWISE_FTYPE)0 : (LIBXSMM_DNN_ELTWISE_FTYPE)1;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise complement over an m x n tile: dst = 1 - src (leading dimension ld). */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      dst[off+row] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 - src[off+row];
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Elementwise complement of the square over an m x n tile: dst = 1 - src^2. */
  libxsmm_blasint row = 0, col;
  for ( col = 0; col < n; ++col ) {
    const libxsmm_blasint off = col * ld;
    LIBXSMM_PRAGMA_SIMD
    for ( row = 0; row < m; ++row ) {
      const LIBXSMM_DNN_ELTWISE_FTYPE v = src[off+row];
      dst[off+row] = (LIBXSMM_DNN_ELTWISE_FTYPE)1 - (v * v);
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_mask_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, float* dst) {
  libxsmm_blasint i, j;
  /* Rounds every FP32 value of the m x n tile to the nearest BF16-representable
   * value (round-to-nearest, ties-to-even) and stores it back as FP32 with the
   * low 16 mantissa bits cleared. NaN and Inf are masked without rounding so
   * the exponent cannot be corrupted by the rounding increment. */
  for ( j = 0; j < n; ++j ) {
    for ( i = 0; i < m; ++i ) {
      /* float<->uint reinterpretation via union: the previous pointer casts
       * (*(unsigned int*)&src[...], *(float*)ptr) violated strict aliasing
       * (C11 6.5p7); union-based type punning is well-defined in C. */
      union { float f; unsigned int u; } bits;
      unsigned int int_round;
      bits.f = src[(j*ld)+i];
      int_round = bits.u;
      /* we don't round NaN and inf (all-ones exponent field) */
      if ( (int_round & 0x7f800000) != 0x7f800000 ) {
        /* perform round nearest tie even: add 0x7fff plus the LSB of the
         * surviving mantissa so exact halves round to an even result */
        const unsigned int fixup = (int_round >> 16) & 1;
        int_round = int_round + 0x00007fff + fixup;
      }
      /* chop bits to create BFP16 stored in an FP32 container */
      bits.u = int_round & 0xffff0000;
      dst[(j*ld)+i] = bits.f;
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_cvt_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, libxsmm_bfloat16* dst) {
  libxsmm_blasint i, j;
  /* Converts the m x n FP32 tile to BF16 using round-to-nearest, ties-to-even;
   * NaN and Inf are truncated without rounding so the exponent field cannot be
   * corrupted by the rounding increment. */
  for ( j = 0; j < n; ++j ) {
    for ( i = 0; i < m; ++i ) {
      /* float->uint reinterpretation via union: the previous pointer cast
       * (*(unsigned int*)&src[...]) violated strict aliasing (C11 6.5p7). */
      union { float f; unsigned int u; } bits;
      unsigned int int_round;
      bits.f = src[(j*ld)+i];
      int_round = bits.u;
      /* we don't round NaN and inf (all-ones exponent field) */
      if ( (int_round & 0x7f800000) != 0x7f800000 ) {
        /* perform round nearest tie even */
        const unsigned int fixup = (int_round >> 16) & 1;
        int_round = int_round + 0x00007fff + fixup;
      }
      /* create the bfp16 value by shifting out the lower 16 bits */
      dst[(j*ld)+i] = (unsigned short)(int_round >> 16);
    }
  }
}
LIBXSMM_API_INTERN void libxsmm_internal_matrix_cvt_bf16_fp32_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, libxsmm_bfloat16 *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst) {
  /* Up-converts an m x n BF16 tile to FP32 (both use leading dimension ld). */
  libxsmm_blasint i, j;
  /* BF16->FP32 via union: the BF16 bits become the upper 16 bits of the FP32
   * word, the lower 16 bits stay zero (i[1] is the high half-word on the
   * little-endian targets this union scheme is written for). */
  libxsmm_bfloat16_hp t;
  t.i[0] = 0;
  for ( j = 0; j < n; ++j ) {
    for ( i = 0; i < m; ++i ) {
      t.i[1] = src[(j*ld)+i];
      dst[(j*ld)+i] = t.f;
    }
  }
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Kunal Banerjee, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_ELEMENTWISE_H
#define LIBXSMM_DNN_ELEMENTWISE_H
#include <libxsmm.h>
#if !defined(LIBXSMM_DNN_ELTWISE_FTYPE)
# define LIBXSMM_DNN_ELTWISE_FTYPE float
#endif
LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_add(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *a, LIBXSMM_DNN_ELTWISE_FTYPE *b, LIBXSMM_DNN_ELTWISE_FTYPE *c, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_transpose(libxsmm_blasint rows, libxsmm_blasint cols, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_inverse(libxsmm_blasint size, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_1D_2D(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint bm, libxsmm_blasint bn, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst, int start_thread, int tid, int nthreads);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_zero_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sub_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_copy_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_eltwise_fma_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src0, LIBXSMM_DNN_ELTWISE_FTYPE *src1, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_add_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, LIBXSMM_DNN_ELTWISE_FTYPE *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_bcst_cvt_bf16_fp32_colvector_const_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *srcdst, libxsmm_bfloat16 *colv, LIBXSMM_DNN_ELTWISE_FTYPE const_bias);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_sigmoid_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_tanh_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_relu_inverse_inplace_eltwise_mult_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_complement_square_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, LIBXSMM_DNN_ELTWISE_FTYPE *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_mask_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, float* dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_rne_cvt_fp32_bfp16_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, float* src, libxsmm_bfloat16* dst);
LIBXSMM_API_INTERN void libxsmm_internal_matrix_cvt_bf16_fp32_ld(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ld, libxsmm_bfloat16 *src, LIBXSMM_DNN_ELTWISE_FTYPE *dst);
#endif /*LIBXSMM_DNN_ELEMENTWISE_H*/
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fullyconnected_backward_weight_update.h"
#include "libxsmm_dnn_fullyconnected_forward.h"
#include "libxsmm_main.h"
LIBXSMM_API libxsmm_dnn_fullyconnected* libxsmm_dnn_create_fullyconnected(libxsmm_dnn_fullyconnected_desc fullyconnected_desc, libxsmm_dnn_err_t* status) {
libxsmm_dnn_fullyconnected* handle = 0;
/* init libxsmm */
LIBXSMM_INIT
if ( ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
/* zero entire content; not only safer but also sets data and code pointers to NULL */
handle = (libxsmm_dnn_fullyconnected*)calloc(1, sizeof(libxsmm_dnn_fullyconnected));
if (0 != handle) {
*status = LIBXSMM_DNN_SUCCESS;
/* let's make the description persistent */
handle->desc = fullyconnected_desc;
handle->target_archid = libxsmm_target_archid;
if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && ((handle->desc.C % 16 != 0) || (handle->desc.K % 16 != 0)) ) {
handle->target_archid = LIBXSMM_X86_AVX512_CPX;
}
/* @TODO perhaps we need a better switch here */
if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
handle->bk = handle->desc.bk;
handle->bn = handle->desc.bn;
handle->bc = handle->desc.bc;
if ( handle->desc.N % handle->bn != 0 ) {
handle->bn = handle->desc.N;
*status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING;
}
if ( handle->desc.C % handle->bc != 0 ) {
handle->bc = handle->desc.C;
*status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING;
}
if ( handle->desc.K % handle->bk != 0 ) {
handle->bk = handle->desc.K;
*status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING;
}
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
#if 0
handle->fwd_bf = atoi(getenv("FWD_BF"));
handle->bwd_bf = atoi(getenv("BWD_BF"));
handle->upd_bf = atoi(getenv("UPD_BF"));
handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
#else
/* Initialize with default values */
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 0;
handle->bwd_2d_blocking = 0;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 1;
handle->fwd_column_teams = 1;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 14;
handle->fwd_column_teams = 2;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 7;
handle->fwd_column_teams = 4;
handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 7;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 7;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 1;
handle->fwd_column_teams = 1;
handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 1;
handle->fwd_column_teams = 1;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 14;
handle->bwd_column_teams = 2;
handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
}
if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 6;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 6;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 6;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 12;
handle->bwd_column_teams = 2;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 6;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 6;
handle->upd_column_teams = 4;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
#endif
} else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
#if 0
handle->fwd_bf = atoi(getenv("FWD_BF"));
handle->bwd_bf = atoi(getenv("BWD_BF"));
handle->upd_bf = atoi(getenv("UPD_BF"));
handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
#else
if (handle->desc.compressed_A > 0) {
handle->compressed_A = 1;
handle->sparsity_factor_A = handle->desc.sparsity_factor_A;
}
/* Initialize with default values */
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 0;
handle->bwd_2d_blocking = 0;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 1;
handle->fwd_column_teams = 1;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
if (handle->desc.threads == 14) {
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 1;
handle->bwd_2d_blocking = 1;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 2;
handle->fwd_column_teams = 7;
handle->bwd_row_teams = 2;
handle->bwd_column_teams = 7;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
}
if (handle->desc.threads == 2) {
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 1;
handle->bwd_2d_blocking = 1;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 2;
handle->fwd_column_teams = 1;
handle->bwd_row_teams = 2;
handle->bwd_column_teams = 1;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
}
if (handle->desc.threads == 4) {
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 1;
handle->bwd_2d_blocking = 1;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 2;
handle->fwd_column_teams = 2;
handle->bwd_row_teams = 2;
handle->bwd_column_teams = 2;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
}
if (handle->desc.threads == 8) {
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 1;
handle->bwd_2d_blocking = 1;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 2;
handle->fwd_column_teams = 4;
handle->bwd_row_teams = 2;
handle->bwd_column_teams = 4;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
}
if (handle->desc.threads == 16) {
handle->fwd_bf = 1;
handle->bwd_bf = 1;
handle->upd_bf = 1;
handle->fwd_2d_blocking = 1;
handle->bwd_2d_blocking = 1;
handle->upd_2d_blocking = 0;
handle->fwd_row_teams = 2;
handle->fwd_column_teams = 8;
handle->bwd_row_teams = 2;
handle->bwd_column_teams = 8;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1;
handle->ofm_subtasks = 1;
}
if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 14;
handle->fwd_column_teams = 2;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 7;
handle->fwd_column_teams = 4;
handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 7;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 7;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 1;
handle->fwd_column_teams = 1;
handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 1;
handle->fwd_column_teams = 1;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 14;
handle->bwd_column_teams = 2;
handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
}
if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 6;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 6;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 6;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 12;
handle->bwd_column_teams = 2;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 1;
handle->bwd_column_teams = 1;
handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 1;
handle->upd_column_teams = 1;
handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 5;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 0;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 5;
handle->upd_column_teams = 4;
handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) {
handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
handle->fwd_2d_blocking = 0;
handle->fwd_row_teams = 6;
handle->fwd_column_teams = 4;
handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
handle->bwd_2d_blocking = 1;
handle->bwd_row_teams = 5;
handle->bwd_column_teams = 4;
handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
handle->upd_2d_blocking = 0;
handle->upd_row_teams = 6;
handle->upd_column_teams = 4;
handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
}
#endif
/* In this case force 2D decomposition */
if (handle->compressed_A == 1) {
handle->fwd_2d_blocking = 1;
handle->fwd_row_teams = 2;
while (handle->desc.threads % handle->fwd_row_teams != 0) {
handle->fwd_row_teams--;
}
handle->fwd_column_teams = handle->desc.threads/handle->fwd_row_teams;
}
}
} else {
/* check that we cannot fuse */
if ( handle->desc.fuse_ops != LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
free( handle );
*status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
return 0;
}
/* we need to compute the memory layout given the */
if ( (handle->desc.C % 16 == 0) && (handle->desc.K % 16 == 0) ) {
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
*status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
&(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block),
LIBXSMM_DNN_DATATYPE_F32, LIBXSMM_DNN_DATATYPE_F32 );
} else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
*status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
&(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block),
handle->desc.datatype_in, handle->desc.datatype_out );
} else {
/* should not happen, not implemented */
}
} else if ( (handle->desc.C % 64 == 0) && (handle->desc.K == 1000) ) {
/* @TODO this a hack for the last FC layer */
handle->ifmblock = 64;
handle->fm_lp_block = 1;
handle->ofmblock = 10;
} else if ( (handle->desc.C % 16 == 0) && (handle->desc.K == 1000) ) {
/* @TODO this a hack for the last FC layer */
handle->ifmblock = 16;
handle->fm_lp_block = 1;
handle->ofmblock = 10;
} else {
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
free( handle );
return 0;
}
/* compute the outer blocks */
handle->blocksifm = handle->desc.C / handle->ifmblock;
handle->blocksofm = handle->desc.K / handle->ofmblock;
}
/* create barrier */
handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
/* If in SPR, generate tilerelease kernel */
if ((handle->target_archid >= LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bk, handle->bk, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL);
}
/* calculate scratch size */
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
handle->scratch_size = sizeof(float) * ( ( (size_t)handle->desc.C * (size_t)handle->desc.N ) + ( (size_t)handle->desc.C * (size_t)handle->desc.K ) );
} else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
/* Let's allocate maximum required scratch */
size_t size_fwd = sizeof(float) * LIBXSMM_MAX(handle->desc.K * handle->desc.N, handle->desc.threads * LIBXSMM_MAX(handle->bk * handle->bn, handle->desc.K));
/* In case of K = 1 we pad A and B to "bk=2" */
size_t size_bwd = (handle->desc.K != 1) ? ( sizeof(float) * LIBXSMM_MAX(handle->desc.C * handle->desc.N, handle->desc.threads * handle->bc * handle->bn) + sizeof(libxsmm_bfloat16) * handle->desc.C * handle->desc.K ) : ( sizeof(float) * handle->desc.C * handle->desc.N + sizeof(libxsmm_bfloat16) * handle->desc.C * 2 + sizeof(libxsmm_bfloat16) * 2 * handle->desc.N );
size_t size_upd = sizeof(float) * LIBXSMM_MAX(handle->desc.C * handle->desc.K, handle->desc.threads * handle->bc * handle->bk) + sizeof(libxsmm_bfloat16) * handle->desc.threads * handle->bk * handle->bc + sizeof(libxsmm_bfloat16) * (handle->desc.N * (handle->desc.C + handle->desc.K));
if (handle->compressed_A == 1) {
size_fwd += handle->desc.threads * handle->desc.C * handle->bk *sizeof(libxsmm_bfloat16);
}
handle->scratch_size = LIBXSMM_MAX(LIBXSMM_MAX(size_fwd, size_bwd), size_upd);
handle->doutput_scratch_mark = handle->scratch_size;
handle->scratch_size += 2 * sizeof(libxsmm_bfloat16) * handle->desc.N * handle->desc.K;
} else {
handle->scratch_size = sizeof(float) * ( (((size_t)handle->desc.C + (size_t)handle->desc.K) * (size_t)handle->desc.N) + ((size_t)handle->desc.C * (size_t)handle->desc.K) );
}
/* create code pointers in some special cases */
if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) && ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) ) {
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
float alpha = 1.0f;
/* beta is set to 1 for ncnc kcck format because ifm is split into 2 blocks */
float beta = 1.0f;
float zerobeta = 0.0f;
int updflags = LIBXSMM_GEMM_FLAGS( 'N', 'T' );
/* For UPD kernels we consider subtasking... */
libxsmm_blasint M = handle->bk/handle->ofm_subtasks;
libxsmm_blasint N = handle->bc/handle->ifm_subtasks;
libxsmm_blasint lda = (libxsmm_blasint)handle->bk;
libxsmm_blasint ldb = (libxsmm_blasint)handle->bc;
libxsmm_blasint ldc = (libxsmm_blasint)handle->bk;
handle->gemm_fwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
handle->gemm_fwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
handle->gemm_bwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL);
handle->gemm_bwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL);
/* Transpose kernel used for weight transpose in bwd pass */
handle->tr_kernel = libxsmm_dispatch_meltw_unary((libxsmm_blasint)(handle->bk), (libxsmm_blasint)(handle->bc), (const libxsmm_blasint*)&(handle->bk), (const libxsmm_blasint*)&(handle->bc), LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
/* update has different LDs */
lda = (libxsmm_blasint)handle->bk;
ldb = (libxsmm_blasint)handle->bc;
ldc = (libxsmm_blasint)handle->bk;
handle->gemm_upd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, &updflags, NULL);
handle->gemm_upd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, &updflags, NULL);
} else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
float alpha = 1.0f;
float beta = 1.0f;
float zerobeta = 0.0f;
/* For UPD kernels we consider subtasking... */
libxsmm_blasint M = handle->bk/handle->ofm_subtasks;
libxsmm_blasint N = handle->bc/handle->ifm_subtasks;
libxsmm_blasint lda = (libxsmm_blasint)handle->bk;
libxsmm_blasint ldb = (libxsmm_blasint)handle->bc;
libxsmm_blasint ldc = (libxsmm_blasint)handle->bk;
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
libxsmm_meltw_flags fusion_flags;
int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
libxsmm_blasint unroll_hint = (handle->desc.C/handle->bc)/handle->fwd_bf;
handle->gemm_fwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &beta, &l_flags, NULL);
handle->gemm_fwd2.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL);
handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bn, handle->bc, &lda, &ldb, &ldc, NULL, &beta, &l_tc_flags, NULL);
handle->gemm_fwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL);
fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_OVERWRITE_C;
handle->gemm_fwd4.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_ACT_RELU_OVERWRITE_C;
handle->gemm_fwd5.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_ACT_SIGM_OVERWRITE_C;
handle->gemm_fwd6.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_RELU_OVERWRITE_C;
handle->gemm_fwd7.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_SIGM_OVERWRITE_C;
handle->gemm_fwd8.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT, LIBXSMM_DATATYPE_F32, fusion_flags, 0, 0, 0, 0);
if (handle->compressed_A == 1) {
fusion_flags = LIBXSMM_MELTW_FLAG_FUSE_NONE;
handle->gemm_fwd9.xgemm.bsmrs_meltwfused = libxsmm_bsmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &beta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
handle->gemm_fwd10.xgemm.bsmrs_meltwfused = libxsmm_bsmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bn, handle->bc, &lda, &ldb, &ldc, NULL, &beta, &l_tc_flags, NULL);
handle->gemm_fwd11.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_OVERWRITE_C;
handle->gemm_fwd12.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_ACT_RELU_OVERWRITE_C;
handle->gemm_fwd13.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_ACT_SIGM_OVERWRITE_C;
handle->gemm_fwd14.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_RELU_OVERWRITE_C;
handle->gemm_fwd15.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
fusion_flags = LIBXSMM_MELTW_FLAG_COLBIAS_ACT_SIGM_OVERWRITE_C;
handle->gemm_fwd16.xgemm.bmrs_meltwfused = libxsmm_bmmdispatch_reducebatch_strd_meltwfused_unroll(handle->bk, handle->bn, handle->bc, (handle->bk*handle->bc*sizeof(libxsmm_bfloat16))/handle->sparsity_factor_A, handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL, LIBXSMM_MELTW_OPERATION_COLBIAS_ACT_DECOMPRESS_A, LIBXSMM_DATATYPE_F32, fusion_flags, handle->sparsity_factor_A, 0, 0, 0);
}
/* Also JIT eltwise functions... */
handle->fwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->bk, handle->bn, &ldc, &ldc, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY);
handle->fwd_cvtfp32bf16_relu_kernel = libxsmm_dispatch_meltw_unary(handle->bk, handle->bn, &ldc, &ldc, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_BITMASK, LIBXSMM_MELTW_TYPE_UNARY_RELU);
handle->fwd_sigmoid_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->bk, handle->bn, &ldc, &ldc, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_SIGMOID);
} else {
handle->gemm_fwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
handle->gemm_fwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
handle->gemm_fwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
}
/* Special bwd kernels for K == 1 */
if (handle->desc.K == 1) {
libxsmm_blasint _bk = 2;
handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &beta, NULL, NULL);
handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &zerobeta, NULL, NULL);
} else {
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
libxsmm_blasint unroll_hint = (handle->desc.K/handle->bk)/handle->bwd_bf;
handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &ldb, &lda, &ldb, &alpha, &beta, &l_flags, NULL);
handle->gemm_bwd2.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &ldb, &lda, &ldb, &alpha, &zerobeta, &l_flags, NULL);
handle->bwd_config_kernel = libxsmm_bsmmdispatch(handle->bc, handle->bn, handle->bk, &ldb, &lda, &ldb, NULL, &beta, &l_tc_flags, NULL);
handle->gemm_bwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &ldb, &lda, &ldb, &alpha, &zerobeta, &l_flags, NULL);
/* Also JIT eltwise functions... */
handle->bwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->bc, handle->bn, &ldb, &ldb, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY);
handle->bwd_relu_kernel = libxsmm_dispatch_meltw_unary(handle->bc, handle->bn, &ldb, &ldb, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_BITMASK, LIBXSMM_MELTW_TYPE_UNARY_RELU_INV);
} else {
handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL);
handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL);
}
}
lda = (libxsmm_blasint)handle->bk;
ldb = (libxsmm_blasint)handle->bn;
ldc = (libxsmm_blasint)handle->bk;
if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
libxsmm_blasint unroll_hint = (handle->desc.N/handle->bn)/handle->upd_bf;
handle->gemm_upd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &beta, &l_flags, NULL);
handle->gemm_upd2.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd_unroll(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL);
handle->upd_config_kernel = libxsmm_bsmmdispatch(M, N, handle->bn, &lda, &ldb, &ldc, NULL, &beta, &l_tc_flags, NULL);
l_flags = l_flags | LIBXSMM_GEMM_FLAG_VNNI_C;
handle->gemm_upd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd_unroll(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), unroll_hint, &lda, &ldb, &ldc, &alpha, &zerobeta, &l_flags, NULL);
} else {
handle->gemm_upd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
handle->gemm_upd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
}
} else {
}
}
} else {
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
}
} else {
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
return handle;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* release the barrier before the handle that owns it */
    if (0 != handle->barrier) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    /* the handle was heap-allocated at creation; cast drops constness for free */
    free((libxsmm_dnn_fullyconnected*)handle);
  }
  return status;
}
/* Creates a heap-allocated tensor datalayout descriptor for the given tensor
 * type of a fully-connected handle. On success the layout is returned and
 * *status is LIBXSMM_DNN_SUCCESS; on any failure NULL is returned and *status
 * carries the error. The caller owns the returned layout (and its dim arrays).
 * Fixes vs. previous revision: every allocation-failure path now releases both
 * dim arrays before the layout (free(NULL) is a no-op, so unconditional frees
 * are safe), the NHWC branch no longer returns a half-initialized layout with
 * SUCCESS on OOM, allocation failures consistently report
 * LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS, and an unsupported bias datatype is
 * reported instead of silently returning an empty layout. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
        /* activation (input/output) tensors */
        layout->format = handle->desc.buffer_format;
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            /* blocked LIBXSMM layout, F32: [N][C-blocks][H=1][W=1][C-block] */
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = 1;
                layout->dim_size[2] = 1;
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = 1;
                layout->dim_size[2] = 1;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else { /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* one of the two mallocs may have succeeded; free(NULL) is a no-op */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            /* mixed precision: BF16 input activations, F32 output activations */
            if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
              layout->datatype = handle->desc.datatype_in;
              layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
              layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
              if (0 != layout->dim_type && 0 != layout->dim_size) {
                layout->num_dims = 5;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = 1;
                layout->dim_size[2] = 1;
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
              }
            } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
              layout->datatype = handle->desc.datatype_out;
              layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
              layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
              if (0 != layout->dim_type && 0 != layout->dim_size) {
                layout->num_dims = 5;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = 1;
                layout->dim_size[2] = 1;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
              }
            } else {
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            /* plain NHWC (here effectively NC with H=W=1) */
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = 1;
                layout->dim_size[2] = 1;
                layout->dim_size[3] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->desc.K;
                layout->dim_size[1] = 1;
                layout->dim_size[2] = 1;
                layout->dim_size[3] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* previously fell through with SUCCESS and NULL dim arrays */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            /* blocked NC layout: [N-blocks][C-blocks][bn][bc] (or [bk] for outputs) */
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) {
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = (unsigned int)handle->bc;
                layout->dim_size[1] = (unsigned int)handle->bn;
                layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
                layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) ) {
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = (unsigned int)handle->bk;
                layout->dim_size[1] = (unsigned int)handle->bn;
                layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
                layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* allocation failure: release partial arrays and report it as such
               * (previously leaked and reported UNKNOWN_TENSOR_TYPE) */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_FILTER) ) {
        /* filter (weight) tensors */
        layout->format = handle->desc.filter_format;
        layout->tensor_type = LIBXSMM_DNN_FILTER;
        if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            /* blocked KCRS-style layout with R=S=1 */
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 6;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_size[0] = handle->ofmblock;
              layout->dim_size[1] = handle->ifmblock;
              layout->dim_size[2] = 1;
              layout->dim_size[3] = 1;
              layout->dim_size[4] = handle->blocksifm;
              layout->dim_size[5] = handle->blocksofm;
            } else {
              /* previously leaked the dim array that did get allocated */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ||
                      ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
            /* BF16 adds an innermost low-precision pairing dimension (fm_lp_block) */
            layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 7;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_size[0] = handle->fm_lp_block;
              layout->dim_size[1] = handle->ofmblock;
              layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
              layout->dim_size[3] = 1;
              layout->dim_size[4] = 1;
              layout->dim_size[5] = handle->blocksifm;
              layout->dim_size[6] = handle->blocksofm;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            /* flat RSCK layout with R=S=1 */
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
              layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
              layout->dim_size[2] = 1;
              layout->dim_size[3] = 1;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
            /* blocked CK layout: [K-blocks][C-blocks][bc][bk] */
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) {
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
                layout->dim_size[0] = (unsigned int)handle->bk;
                layout->dim_size[1] = (unsigned int)handle->bc;
                layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
                layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* allocation failure: release partial arrays and report it as such
               * (previously leaked and reported UNKNOWN_TENSOR_TYPE) */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ) {
            /* BF16 blocked CK layout with innermost pairing dimension of 2 */
            layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) {
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
                layout->dim_size[0] = (unsigned int)2;
                layout->dim_size[1] = (unsigned int)handle->bk;
                layout->dim_size[2] = (unsigned int)handle->bc/2;
                layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
                layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) {
        /* per-channel bias: [K-blocks][bk] */
        layout->format = handle->desc.buffer_format;
        layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
        if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
          if ( (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            layout->datatype = handle->desc.datatype_out;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 2;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = (unsigned int)handle->bk;
              layout->dim_size[1] = (unsigned int)(handle->desc.K / handle->bk);
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            /* previously fell through returning an empty layout with SUCCESS */
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
        }
      } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) {
        /* one-byte-per-element ReLU mask, flattened to N*K */
        layout->format = handle->desc.buffer_format;
        layout->tensor_type = LIBXSMM_DNN_RELU_MASK;
        if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
          layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 1;
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
            layout->dim_size[0] = handle->desc.N * handle->desc.K;
          } else {
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
LIBXSMM_API size_t libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status) {
  /* Reports the scratch size the handle requires; the value is padded by 64 bytes
   * so that a caller-supplied buffer without alignment guarantees can still be
   * realigned internally (see libxsmm_dnn_fullyconnected_bind_scratch). */
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    result = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */
  }
  return result;
}
LIBXSMM_API void* libxsmm_dnn_fullyconnected_get_scratch_ptr(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status)
{
  /* Returns the scratch pointer currently bound to the handle
   * (NULL when nothing is bound or the handle itself is NULL). */
  void* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    result = handle->scratch;
  }
  return result;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected* handle, const void* scratch) {
  /* Binds a caller-owned scratch buffer to the handle. The stored pointer is
   * rounded up to the next 64-byte boundary; the caller must therefore provide
   * the padded size reported by libxsmm_dnn_fullyconnected_get_scratch_size. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  const uintptr_t address = (uintptr_t)scratch;
  if (0 == scratch) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  }
  else if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    /* round the user pointer up to the next 64-byte boundary if needed */
    const uintptr_t misalignment = address % 64;
    handle->scratch = (0 == misalignment)
      ? (void*)address
      : (void*)(address + (64 - misalignment));
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected* handle) {
  /* Detaches the scratch buffer from the handle; the memory itself is
   * owned (and freed) by the caller. */
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  handle->scratch = 0;
  return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* reject tensor roles that the fully-connected layer does not use */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_FILTER:
    case LIBXSMM_DNN_GRADIENT_FILTER:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
    case LIBXSMM_DNN_RELU_MASK:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle || 0 == tensor) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  {
    /* the tensor's layout must match the layout the handle expects for this role */
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout(handle, type, &status);
    if (0 == libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status)) {
      switch (type) {
        case LIBXSMM_DNN_REGULAR_INPUT:         handle->reg_input   = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_GRADIENT_INPUT:        handle->grad_input  = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_REGULAR_OUTPUT:        handle->reg_output  = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_GRADIENT_OUTPUT:       handle->grad_output = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_REGULAR_FILTER:        handle->reg_filter  = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_GRADIENT_FILTER:       handle->grad_filter = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:  handle->reg_bias    = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: handle->grad_bias   = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_RELU_MASK:             handle->relumask    = (libxsmm_dnn_tensor*)tensor; break;
        default: break; /* unreachable: filtered above */
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  /* Returns the tensor currently bound to the handle under the given role,
   * or NULL (with *status set) for an unsupported role or a NULL handle. */
  libxsmm_dnn_tensor* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  /* the role check happens before the handle check (matches bind/release order) */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_FILTER:
    case LIBXSMM_DNN_GRADIENT_FILTER:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
    case LIBXSMM_DNN_RELU_MASK:
      break;
    default:
      *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      return result;
  }
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return result;
  }
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:         result = handle->reg_input;   break;
    case LIBXSMM_DNN_GRADIENT_INPUT:        result = handle->grad_input;  break;
    case LIBXSMM_DNN_REGULAR_OUTPUT:        result = handle->reg_output;  break;
    case LIBXSMM_DNN_GRADIENT_OUTPUT:       result = handle->grad_output; break;
    case LIBXSMM_DNN_REGULAR_FILTER:        result = handle->reg_filter;  break;
    case LIBXSMM_DNN_GRADIENT_FILTER:       result = handle->grad_filter; break;
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:  result = handle->reg_bias;    break;
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: result = handle->grad_bias;   break;
    case LIBXSMM_DNN_RELU_MASK:             result = handle->relumask;    break;
    default: break; /* unreachable: filtered above */
  }
  return result;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type) {
  /* Clears the binding for the given tensor role on the handle;
   * the tensor object itself remains owned by the caller. */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_FILTER:
    case LIBXSMM_DNN_GRADIENT_FILTER:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
    case LIBXSMM_DNN_RELU_MASK:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:         handle->reg_input   = 0; break;
    case LIBXSMM_DNN_GRADIENT_INPUT:        handle->grad_input  = 0; break;
    case LIBXSMM_DNN_REGULAR_OUTPUT:        handle->reg_output  = 0; break;
    case LIBXSMM_DNN_GRADIENT_OUTPUT:       handle->grad_output = 0; break;
    case LIBXSMM_DNN_REGULAR_FILTER:        handle->reg_filter  = 0; break;
    case LIBXSMM_DNN_GRADIENT_FILTER:       handle->grad_filter = 0; break;
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:  handle->reg_bias    = 0; break;
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: handle->grad_bias   = 0; break;
    case LIBXSMM_DNN_RELU_MASK:             handle->relumask    = 0; break;
    default: break; /* unreachable: filtered above */
  }
  return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  /* Dispatches a single-threaded slice of the requested pass (forward or
   * backward/update) to the kernel matching the handle's buffer/filter formats. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    /* only two format combinations are implemented */
    const int is_custom = (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM)
                       && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM);
    const int is_packed = (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED)
                       && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED);
    if (LIBXSMM_DNN_COMPUTE_KIND_FWD == kind) {
      if (is_custom) {
        status = libxsmm_dnn_fullyconnected_st_fwd_custom( handle, start_thread, tid );
      }
      else if (is_packed) {
        status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck( handle, start_thread, tid );
      }
      else {
        status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC;
      }
    }
    else if (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind || LIBXSMM_DNN_COMPUTE_KIND_UPD == kind || LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) {
      /* the bwd/upd kernels take the kind and select internally */
      if (is_custom) {
        status = libxsmm_dnn_fullyconnected_st_bwdupd_custom( handle, kind, start_thread, tid );
      }
      else if (is_packed) {
        status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck( handle, kind, start_thread, tid );
      }
      else {
        status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC;
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_INVALID_KIND;
    }
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fullyconnected_backward_weight_update.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
#if 0
#define USE_CLDEMOTE
#endif
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Transposes one 16x16 bf16 tile held in VNNI (row-pair interleaved) layout.
 * Eight 512-bit registers carry the whole tile (each register holds 32 bf16
 * values, i.e. two interleaved logical rows); the caller passes strides in
 * bf16 elements -- callers in this file pass ld*2 to account for the pairing.
 * Falls back to a no-op stub when AVX512_CORE intrinsics are unavailable. */
void bf16_vnni_transpose_16x16(void* source_void, void* dest_void, int source_stride, int dest_stride)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
libxsmm_bfloat16 *source = (libxsmm_bfloat16*)source_void;
libxsmm_bfloat16 *dest = (libxsmm_bfloat16*)dest_void;
__m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
__m512i tmp0, tmp1, tmp2, tmp3;
/* byte-shuffle pattern: within each 128-bit lane, swap the middle two
 * 32-bit element pairs (abcdefgh -> abefcdgh) to prepare the 64-bit unpacks */
const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100);
/* load 8 VNNI rows (8 x 32 bf16 = the full 16x16 tile) */
zmm0 = _mm512_loadu_si512(source);
zmm1 = _mm512_loadu_si512(source + source_stride);
zmm2 = _mm512_loadu_si512(source + source_stride*2);
zmm3 = _mm512_loadu_si512(source + source_stride*3);
zmm4 = _mm512_loadu_si512(source + source_stride*4);
zmm5 = _mm512_loadu_si512(source + source_stride*5);
zmm6 = _mm512_loadu_si512(source + source_stride*6);
zmm7 = _mm512_loadu_si512(source + source_stride*7);
/* stage 1: in-lane pair reordering */
zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh);
zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh);
zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh);
zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh);
zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh);
zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh);
zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh);
zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh);
/* stage 2: 64-bit interleave of register pairs */
tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1);
tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1);
tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3);
tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3);
zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5);
zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5);
zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7);
zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7);
/* stage 3: gather even/odd 128-bit lanes across register pairs */
zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88);
zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd);
zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88);
zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd);
tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88);
tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd);
tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88);
tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd);
/* stage 4: combine lower/upper 256-bit halves into the transposed rows */
zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88);
zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88);
zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88);
zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88);
zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd);
zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd);
zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd);
zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd);
/* write the 8 transposed VNNI rows */
_mm512_storeu_si512(dest, zmm0);
_mm512_storeu_si512(dest + dest_stride, zmm1);
_mm512_storeu_si512(dest + dest_stride * 2, zmm2);
_mm512_storeu_si512(dest + dest_stride * 3, zmm3);
_mm512_storeu_si512(dest + dest_stride * 4, zmm4);
_mm512_storeu_si512(dest + dest_stride * 5, zmm5);
_mm512_storeu_si512(dest + dest_stride * 6, zmm6);
_mm512_storeu_si512(dest + dest_stride * 7, zmm7);
#ifdef USE_CLDEMOTE
/* optionally push the freshly written lines out of this core's cache */
_mm_cldemote(dest);
_mm_cldemote(dest + dest_stride);
_mm_cldemote(dest + dest_stride * 2);
_mm_cldemote(dest + dest_stride * 3);
_mm_cldemote(dest + dest_stride * 4);
_mm_cldemote(dest + dest_stride * 5);
_mm_cldemote(dest + dest_stride * 6);
_mm_cldemote(dest + dest_stride * 7);
#endif
#else
LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Transposes an MxN bf16 matrix in VNNI layout by walking it in 16x16 tiles
 * and delegating each tile to bf16_vnni_transpose_16x16. Only full tiles are
 * visited (integer division), so M and N are expected to be multiples of 16.
 * Strides are doubled because every VNNI row interleaves two logical rows. */
void bf16_vnni_transpose(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  const int tiles_m = M/16;
  const int tiles_n = N/16;
  int tn, tm;
  for (tn = 0; tn < tiles_n; ++tn) {
    for (tm = 0; tm < tiles_m; ++tm) {
      bf16_vnni_transpose_16x16(src + tn*16*ld_in + tm*32, dst + tm*16*ld_out + tn*32, ld_in*2, ld_out*2);
    }
  }
#else
  LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Transposes a 16x32 bf16 block into a 32x16 block: loads 16 rows of 32 bf16
 * values each from `in` (row stride ld_in elements) and stores 32 rows of 16
 * bf16 values each to `out` (row stride ld_out elements) via a register-only
 * unpack/shuffle network. Falls back to a no-op stub without AVX512_CORE. */
void bf16_transpose_32x16(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
__m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
__m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
const int in_width=ld_in, out_width=ld_out;
/* qword-permute indices to merge lower/upper register halves in the final stage */
const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
/* load the 16 source rows (32 bf16 values per 512-bit load) */
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
re = _mm512_loadu_si512(in + 14*in_width);
rf = _mm512_loadu_si512(in + 15*in_width);
/* stage 1: 16-bit interleave of adjacent row pairs */
t0 = _mm512_unpacklo_epi16(r0,r1);
t1 = _mm512_unpackhi_epi16(r0,r1);
t2 = _mm512_unpacklo_epi16(r2,r3);
t3 = _mm512_unpackhi_epi16(r2,r3);
t4 = _mm512_unpacklo_epi16(r4,r5);
t5 = _mm512_unpackhi_epi16(r4,r5);
t6 = _mm512_unpacklo_epi16(r6,r7);
t7 = _mm512_unpackhi_epi16(r6,r7);
t8 = _mm512_unpacklo_epi16(r8,r9);
t9 = _mm512_unpackhi_epi16(r8,r9);
ta = _mm512_unpacklo_epi16(ra,rb);
tb = _mm512_unpackhi_epi16(ra,rb);
tc = _mm512_unpacklo_epi16(rc,rd);
td = _mm512_unpackhi_epi16(rc,rd);
te = _mm512_unpacklo_epi16(re,rf);
tf = _mm512_unpackhi_epi16(re,rf);
/* stage 2: 32-bit interleave */
r0 = _mm512_unpacklo_epi32(t0,t2);
r1 = _mm512_unpackhi_epi32(t0,t2);
r2 = _mm512_unpacklo_epi32(t1,t3);
r3 = _mm512_unpackhi_epi32(t1,t3);
r4 = _mm512_unpacklo_epi32(t4,t6);
r5 = _mm512_unpackhi_epi32(t4,t6);
r6 = _mm512_unpacklo_epi32(t5,t7);
r7 = _mm512_unpackhi_epi32(t5,t7);
r8 = _mm512_unpacklo_epi32(t8,ta);
r9 = _mm512_unpackhi_epi32(t8,ta);
ra = _mm512_unpacklo_epi32(t9,tb);
rb = _mm512_unpackhi_epi32(t9,tb);
rc = _mm512_unpacklo_epi32(tc,te);
rd = _mm512_unpackhi_epi32(tc,te);
re = _mm512_unpacklo_epi32(td,tf);
rf = _mm512_unpackhi_epi32(td,tf);
/* stage 3: 64-bit interleave */
t0 = _mm512_unpacklo_epi64(r0,r4);
t1 = _mm512_unpackhi_epi64(r0,r4);
t2 = _mm512_unpacklo_epi64(r1,r5);
t3 = _mm512_unpackhi_epi64(r1,r5);
t4 = _mm512_unpacklo_epi64(r2,r6);
t5 = _mm512_unpackhi_epi64(r2,r6);
t6 = _mm512_unpacklo_epi64(r3,r7);
t7 = _mm512_unpackhi_epi64(r3,r7);
t8 = _mm512_unpacklo_epi64(r8,rc);
t9 = _mm512_unpackhi_epi64(r8,rc);
ta = _mm512_unpacklo_epi64(r9,rd);
tb = _mm512_unpackhi_epi64(r9,rd);
tc = _mm512_unpacklo_epi64(ra,re);
td = _mm512_unpackhi_epi64(ra,re);
te = _mm512_unpacklo_epi64(rb,rf);
tf = _mm512_unpackhi_epi64(rb,rf);
/* stage 4: gather even (0x88) / odd (0xdd) 128-bit lanes */
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
ra = _mm512_shuffle_i32x4(tc, td, 0x88);
rb = _mm512_shuffle_i32x4(te, tf, 0x88);
rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
re = _mm512_shuffle_i32x4(tc, td, 0xdd);
rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
/* stage 5: combine halves across registers; each t-register now holds
 * two consecutive output rows (16 bf16 each, in its two 256-bit halves) */
t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
te = _mm512_permutex2var_epi64(re, idx_hi, r6);
tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);
/* store the 32 transposed rows, one 256-bit (16 bf16) extract per row */
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 0*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 1*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 2*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 3*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 4*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 5*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 6*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 7*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 8*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 9*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 10*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 11*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 12*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 13*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 14*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 15*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 16*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 17*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 18*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 19*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 20*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 21*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 22*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 23*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 24*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 25*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 26*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 27*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 28*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 29*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 30*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 31*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#ifdef USE_CLDEMOTE
/* optionally push the freshly written lines out of this core's cache */
_mm_cldemote(out + 0*out_width);
_mm_cldemote(out + 1*out_width);
_mm_cldemote(out + 2*out_width);
_mm_cldemote(out + 3*out_width);
_mm_cldemote(out + 4*out_width);
_mm_cldemote(out + 5*out_width);
_mm_cldemote(out + 6*out_width);
_mm_cldemote(out + 7*out_width);
_mm_cldemote(out + 8*out_width);
_mm_cldemote(out + 9*out_width);
_mm_cldemote(out + 10*out_width);
_mm_cldemote(out + 11*out_width);
_mm_cldemote(out + 12*out_width);
_mm_cldemote(out + 13*out_width);
_mm_cldemote(out + 14*out_width);
_mm_cldemote(out + 15*out_width);
_mm_cldemote(out + 16*out_width);
_mm_cldemote(out + 17*out_width);
_mm_cldemote(out + 18*out_width);
_mm_cldemote(out + 19*out_width);
_mm_cldemote(out + 20*out_width);
_mm_cldemote(out + 21*out_width);
_mm_cldemote(out + 22*out_width);
_mm_cldemote(out + 23*out_width);
_mm_cldemote(out + 24*out_width);
_mm_cldemote(out + 25*out_width);
_mm_cldemote(out + 26*out_width);
_mm_cldemote(out + 27*out_width);
_mm_cldemote(out + 28*out_width);
_mm_cldemote(out + 29*out_width);
_mm_cldemote(out + 30*out_width);
_mm_cldemote(out + 31*out_width);
#endif
#else
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_transpose_32xcols(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
__m512i r0 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
__m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
const int in_width=ld_in, out_width=ld_out;
const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
__mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1);
if (col == 15) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
re = _mm512_loadu_si512(in + 14*in_width);
} else if (col == 14) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
} else if (col == 13) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
rc = _mm512_loadu_si512(in + 12*in_width);
} else if (col == 12) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
} else if (col == 11) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
ra = _mm512_loadu_si512(in + 10*in_width);
} else if (col == 10) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
} else if (col == 9) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
r8 = _mm512_loadu_si512(in + 8*in_width);
} else if (col == 8) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
} else if (col == 7) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
r6 = _mm512_loadu_si512(in + 6*in_width);
} else if (col == 6) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
} else if (col == 5) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
r4 = _mm512_loadu_si512(in + 4*in_width);
} else if (col == 4) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
} else if (col == 3) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
r2 = _mm512_loadu_si512(in + 2*in_width);
} else if (col == 2) {
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
} else if (col == 1) {
r0 = _mm512_loadu_si512(in + 0*in_width);
}
t0 = _mm512_unpacklo_epi16(r0,r1);
t1 = _mm512_unpackhi_epi16(r0,r1);
t2 = _mm512_unpacklo_epi16(r2,r3);
t3 = _mm512_unpackhi_epi16(r2,r3);
t4 = _mm512_unpacklo_epi16(r4,r5);
t5 = _mm512_unpackhi_epi16(r4,r5);
t6 = _mm512_unpacklo_epi16(r6,r7);
t7 = _mm512_unpackhi_epi16(r6,r7);
t8 = _mm512_unpacklo_epi16(r8,r9);
t9 = _mm512_unpackhi_epi16(r8,r9);
ta = _mm512_unpacklo_epi16(ra,rb);
tb = _mm512_unpackhi_epi16(ra,rb);
tc = _mm512_unpacklo_epi16(rc,rd);
td = _mm512_unpackhi_epi16(rc,rd);
te = _mm512_unpacklo_epi16(re,rf);
tf = _mm512_unpackhi_epi16(re,rf);
r0 = _mm512_unpacklo_epi32(t0,t2);
r1 = _mm512_unpackhi_epi32(t0,t2);
r2 = _mm512_unpacklo_epi32(t1,t3);
r3 = _mm512_unpackhi_epi32(t1,t3);
r4 = _mm512_unpacklo_epi32(t4,t6);
r5 = _mm512_unpackhi_epi32(t4,t6);
r6 = _mm512_unpacklo_epi32(t5,t7);
r7 = _mm512_unpackhi_epi32(t5,t7);
r8 = _mm512_unpacklo_epi32(t8,ta);
r9 = _mm512_unpackhi_epi32(t8,ta);
ra = _mm512_unpacklo_epi32(t9,tb);
rb = _mm512_unpackhi_epi32(t9,tb);
rc = _mm512_unpacklo_epi32(tc,te);
rd = _mm512_unpackhi_epi32(tc,te);
re = _mm512_unpacklo_epi32(td,tf);
rf = _mm512_unpackhi_epi32(td,tf);
t0 = _mm512_unpacklo_epi64(r0,r4);
t1 = _mm512_unpackhi_epi64(r0,r4);
t2 = _mm512_unpacklo_epi64(r1,r5);
t3 = _mm512_unpackhi_epi64(r1,r5);
t4 = _mm512_unpacklo_epi64(r2,r6);
t5 = _mm512_unpackhi_epi64(r2,r6);
t6 = _mm512_unpacklo_epi64(r3,r7);
t7 = _mm512_unpackhi_epi64(r3,r7);
t8 = _mm512_unpacklo_epi64(r8,rc);
t9 = _mm512_unpackhi_epi64(r8,rc);
ta = _mm512_unpacklo_epi64(r9,rd);
tb = _mm512_unpackhi_epi64(r9,rd);
tc = _mm512_unpacklo_epi64(ra,re);
td = _mm512_unpackhi_epi64(ra,re);
te = _mm512_unpacklo_epi64(rb,rf);
tf = _mm512_unpackhi_epi64(rb,rf);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
ra = _mm512_shuffle_i32x4(tc, td, 0x88);
rb = _mm512_shuffle_i32x4(te, tf, 0x88);
rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
re = _mm512_shuffle_i32x4(tc, td, 0xdd);
rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
te = _mm512_permutex2var_epi64(re, idx_hi, r6);
tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);
_mm256_mask_storeu_epi16(out + 0*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
_mm256_mask_storeu_epi16(out + 1*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
_mm256_mask_storeu_epi16(out + 2*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
_mm256_mask_storeu_epi16(out + 3*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
_mm256_mask_storeu_epi16(out + 4*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
_mm256_mask_storeu_epi16(out + 5*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
_mm256_mask_storeu_epi16(out + 6*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
_mm256_mask_storeu_epi16(out + 7*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
_mm256_mask_storeu_epi16(out + 8*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
_mm256_mask_storeu_epi16(out + 9*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
_mm256_mask_storeu_epi16(out + 10*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
_mm256_mask_storeu_epi16(out + 11*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
_mm256_mask_storeu_epi16(out + 12*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
_mm256_mask_storeu_epi16(out + 13*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
_mm256_mask_storeu_epi16(out + 14*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
_mm256_mask_storeu_epi16(out + 15*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
_mm256_mask_storeu_epi16(out + 16*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
_mm256_mask_storeu_epi16(out + 17*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
_mm256_mask_storeu_epi16(out + 18*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
_mm256_mask_storeu_epi16(out + 19*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
_mm256_mask_storeu_epi16(out + 20*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
_mm256_mask_storeu_epi16(out + 21*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
_mm256_mask_storeu_epi16(out + 22*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
_mm256_mask_storeu_epi16(out + 23*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
_mm256_mask_storeu_epi16(out + 24*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
_mm256_mask_storeu_epi16(out + 25*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
_mm256_mask_storeu_epi16(out + 26*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
_mm256_mask_storeu_epi16(out + 27*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
_mm256_mask_storeu_epi16(out + 28*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
_mm256_mask_storeu_epi16(out + 29*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
_mm256_mask_storeu_epi16(out + 30*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
_mm256_mask_storeu_epi16(out + 31*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#ifdef USE_CLDEMOTE
_mm_cldemote(out + 0*out_width);
_mm_cldemote(out + 1*out_width);
_mm_cldemote(out + 2*out_width);
_mm_cldemote(out + 3*out_width);
_mm_cldemote(out + 4*out_width);
_mm_cldemote(out + 5*out_width);
_mm_cldemote(out + 6*out_width);
_mm_cldemote(out + 7*out_width);
_mm_cldemote(out + 8*out_width);
_mm_cldemote(out + 9*out_width);
_mm_cldemote(out + 10*out_width);
_mm_cldemote(out + 11*out_width);
_mm_cldemote(out + 12*out_width);
_mm_cldemote(out + 13*out_width);
_mm_cldemote(out + 14*out_width);
_mm_cldemote(out + 15*out_width);
_mm_cldemote(out + 16*out_width);
_mm_cldemote(out + 17*out_width);
_mm_cldemote(out + 18*out_width);
_mm_cldemote(out + 19*out_width);
_mm_cldemote(out + 20*out_width);
_mm_cldemote(out + 21*out_width);
_mm_cldemote(out + 22*out_width);
_mm_cldemote(out + 23*out_width);
_mm_cldemote(out + 24*out_width);
_mm_cldemote(out + 25*out_width);
_mm_cldemote(out + 26*out_width);
_mm_cldemote(out + 27*out_width);
_mm_cldemote(out + 28*out_width);
_mm_cldemote(out + 29*out_width);
_mm_cldemote(out + 30*out_width);
_mm_cldemote(out + 31*out_width);
#endif
#else
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); LIBXSMM_UNUSED(col);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_transpose(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out){
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
/* Transpose an M x N bf16 matrix (pitch ld_in) into out (pitch ld_out) by
 * tiling: 32x16 tiles over the full 16-column chunks of N, then one
 * 32x(N%16) tile per 32-row band for the remaining columns.
 * NOTE(review): assumes M is a multiple of 32 -- confirm at call sites. */
int row, col;
const int col_rem = N % 16;
const int col_full = N - col_rem;
for (row = 0; row < M; row += 32) {
  for (col = 0; col < col_full; col += 16) {
    bf16_transpose_32x16(in + row + ld_in*col, out + col + row*ld_out, ld_in, ld_out);
  }
}
if (col_rem != 0) {
  /* leftover columns: one partial-width tile per 32-row band */
  for (row = 0; row < M; row += 32) {
    bf16_transpose_32xcols(in + row + ld_in*col_full, out + col_full + row*ld_out, col_rem, ld_in, ld_out);
  }
}
#else
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_reformat(libxsmm_bfloat16 *_in, libxsmm_bfloat16 *_out, int M, int N, int ld_in, int ld_out) {
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
/* Reformat an M x N bf16 matrix into VNNI layout: the N dimension is
 * processed in pairs of columns (pitch ld_in) whose elements are
 * interleaved pairwise in the output; an odd trailing column is
 * interleaved with zeros.
 * NOTE(review): assumes M is a multiple of 32, and the output advances by
 * 2*ld_in per pair (ld_out is explicitly unused) -- confirm at call sites. */
int n_full_pairs = N/2, n_pair, m;
int half_n_pair = N%2;
libxsmm_bfloat16 *in = _in, *out = _out;
/* 'selector' sets bit 5 on every odd lane so permutex2var pulls that lane
 * from the second source; offsets_lo/hi duplicate lane ids 0..15 / 16..31 */
const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0);
const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);
const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16);
const __m512i idx_lo = _mm512_or_epi32(selector, offsets_lo);
const __m512i idx_hi = _mm512_or_epi32(selector, offsets_hi);
const __m512i zero_reg = _mm512_setzero_si512();
__m512i n0, n1, out_lo, out_hi;
LIBXSMM_UNUSED(ld_out);
for (n_pair = 0; n_pair < n_full_pairs; n_pair++) {
for (m = 0; m < M; m+=32) {
n0 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m);
n1 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m+ld_in);
out_lo = _mm512_permutex2var_epi16(n0, idx_lo, n1);
out_hi = _mm512_permutex2var_epi16(n0, idx_hi, n1);
_mm512_storeu_si512((libxsmm_bfloat16*)out+m*2, out_lo);
_mm512_storeu_si512((libxsmm_bfloat16*)out+m*2+32, out_hi);
#ifdef USE_CLDEMOTE
_mm_cldemote((libxsmm_bfloat16*)out+m*2);
_mm_cldemote((libxsmm_bfloat16*)out+m*2+32);
#endif
}
in += 2*ld_in;
out += 2*ld_in;
}
if (half_n_pair == 1) {
/* odd trailing column: interleave its elements with zeros */
for (m = 0; m < M; m+=32) {
n0 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m);
n1 = zero_reg;
out_lo = _mm512_permutex2var_epi16(n0, idx_lo, n1);
/* BUGFIX: was idx_lo (copy/paste from the line above), which duplicated
 * elements 0..15 of n0 instead of emitting elements 16..31 */
out_hi = _mm512_permutex2var_epi16(n0, idx_hi, n1);
_mm512_storeu_si512((libxsmm_bfloat16*)out+m*2, out_lo);
_mm512_storeu_si512((libxsmm_bfloat16*)out+m*2+32, out_hi);
#ifdef USE_CLDEMOTE
_mm_cldemote((libxsmm_bfloat16*)out+m*2);
_mm_cldemote((libxsmm_bfloat16*)out+m*2+32);
#endif
}
}
#else
LIBXSMM_UNUSED(_in); LIBXSMM_UNUSED(_out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
#undef USE_CLDEMOTE
/* BWD/UPD pass for the fully-connected layer in custom layout, fp32 in/out:
 * dispatches two small-GEMM kernels and delegates the actual loop nest to the
 * shared template include. Returns LIBXSMM_DNN_SUCCESS or an error code. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
/* leading dimensions for the backward and weight-update GEMM kernels */
libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
element_input_type alpha = (element_input_type)1;
element_input_type beta = (element_input_type)0;
/* this path supports no fused elementwise ops */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
typedef libxsmm_smmfunction gemm_function;
gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* BWD/UPD pass for the fully-connected layer in custom layout, bf16
 * input/filter with fp32 output: dispatches two small-GEMM kernels and
 * delegates the loop nest to the shared template include (with the BF16->F32
 * variant macros defined). Returns LIBXSMM_DNN_SUCCESS or an error code. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef float element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_smmfunction gemm_function;
/* leading dimensions for the backward and weight-update GEMM kernels */
libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
/* FIX: alpha/beta feed libxsmm_smmdispatch as fp32; previously they were
 * initialized via a cast through element_input_type (the bf16 raw type),
 * which is misleading -- 1 and 0 happen to survive, so behavior is equal */
float alpha = 1.0f;
float beta = 0.0f;
/* this path supports no fused elementwise ops */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
# define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* BWD/UPD pass in NCNC/KCCK blocked layout, fp32: pulls the pre-dispatched
 * batch-reduce GEMM kernels from the handle, then includes the generic
 * template once per supported elementwise fusion with the matching feature
 * macros defined. Returns LIBXSMM_DNN_SUCCESS or an error code. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
/* batch-reduce GEMM kernels pre-dispatched on the handle (beta=1 and beta=0 variants) */
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs;
#define LIBXSMM_DNN_FC_BWD_USE_AVX512
/* dispatch on the requested fusion; each branch includes the template with
 * the corresponding LIBXSMM_DNN_FC_BWD_FUSE_* macros defined */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
#undef LIBXSMM_DNN_FC_BWD_USE_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* BWD/UPD pass in NCNC/KCCK blocked layout, bf16 in/out, AVX512-CORE
 * emulation path (no native BF16 instructions): uses portable BF16<->FP32
 * conversion macros and the generic bf16 template, one include per fusion. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* batch-reduce GEMM kernels pre-dispatched on the handle (beta=1 and beta=0 variants) */
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs;
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* dispatch on the requested fusion; each branch includes the template with
 * the corresponding LIBXSMM_DNN_FC_BWD_FUSE_* macros defined */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BWD/UPD pass in NCNC/KCCK blocked layout, bf16 in/out, native AVX512-BF16
 * (CPX) path: compiled only when the toolchain supports the CPX ISA; the
 * LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI macro selects native conversions in the
 * BF16 macro templates. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* batch-reduce GEMM kernels pre-dispatched on the handle (beta=1 and beta=0 variants) */
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* dispatch on the requested fusion; each branch includes the template with
 * the corresponding LIBXSMM_DNN_FC_BWD_FUSE_* macros defined */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Fallback when the toolchain lacks CPX support: forward to the AVX512-CORE
 * emulation path so the symbol always exists. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
return libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BWD/UPD pass in NCNC/KCCK blocked layout, bf16 in/out, AMX-flavored
 * template (uses gemm_bwd3/gemm_upd3 kernels and a bwd tile-config kernel);
 * this variant is compiled when CPX intrinsics are available. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* batch-reduce GEMM kernels pre-dispatched on the handle (beta=1 and beta=0 variants) */
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd3.xgemm.bmrs;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_upd_zerobeta = handle->gemm_upd3.xgemm.bmrs;
libxsmm_bsmmfunction bwd_tile_config_kernel = handle->bwd_config_kernel;
/*libxsmm_bsmmfunction upd_tile_config_kernel = handle->upd_config_kernel;*/
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* dispatch on the requested fusion; each branch includes the template with
 * the corresponding LIBXSMM_DNN_FC_BWD_FUSE_* macros defined */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Fallback when the toolchain lacks CPX support: forward to the AVX512-CORE
 * emulation variant (defined below; declared via the internal header). */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
return libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu(handle, kind, start_thread, tid);
}
#endif
/* BWD/UPD pass in NCNC/KCCK blocked layout, bf16 in/out, AMX-flavored
 * template on the AVX512-CORE emulation path (no native BF16 conversions:
 * LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI is intentionally NOT defined here). */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* batch-reduce GEMM kernels pre-dispatched on the handle (beta=1 and beta=0 variants) */
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd3.xgemm.bmrs;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_upd_zerobeta = handle->gemm_upd3.xgemm.bmrs;
libxsmm_bsmmfunction bwd_tile_config_kernel = handle->bwd_config_kernel;
/*libxsmm_bsmmfunction upd_tile_config_kernel = handle->upd_config_kernel;*/
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* dispatch on the requested fusion; each branch includes the template with
 * the corresponding LIBXSMM_DNN_FC_BWD_FUSE_* macros defined */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Entry point for the BWD/UPD pass in custom layout: validates that all
 * tensors required by the requested compute kind are bound, then routes to
 * the AVX512-specialized implementations when the target architecture allows,
 * otherwise runs a generic dispatch-based fallback inline.
 * Returns LIBXSMM_DNN_SUCCESS or a descriptive error code. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if all required tensors are bound */
if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) {
if (handle->grad_input == 0 || handle->grad_output == 0 ||
handle->reg_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
} else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) {
if (handle->reg_input == 0 || handle->grad_output == 0 ||
handle->grad_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
} else {
/* combined BWD+UPD: needs the union of both tensor sets */
if (handle->grad_input == 0 || handle->grad_output == 0 ||
handle->reg_input == 0 || handle->grad_filter == 0 ||
handle->reg_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32( handle, kind, start_thread, tid);
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__*/
else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32( handle, kind, start_thread, tid);
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* generic fallback: dispatch GEMM kernels here and include the template inline */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
/* leading dimensions for the backward and weight-update GEMM kernels */
libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
element_input_type alpha = (element_input_type)1;
element_input_type beta = (element_input_type)0;
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef libxsmm_bfloat16 element_input_type;
typedef float element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_smmfunction gemm_function;
/* leading dimensions for the backward and weight-update GEMM kernels */
libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
/* FIX: alpha/beta are fp32 GEMM scalars; the old initializers cast through
 * element_input_type (the bf16 raw integer type), which was misleading --
 * 1 and 0 happen to survive the round-trip, so behavior is unchanged */
float alpha = 1.0f;
float beta = 0.0f;
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
# define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* Backward (BWD) and/or weight-update (UPD) pass of the fully-connected layer
 * for the blocked NCNC (activation) / KCCK (filter) tensor format.
 * Validates that the tensors required by 'kind' are bound, then dispatches on
 * datatype and target architecture; a generic F32 fallback (batch-reduce GEMM
 * kernels cached on the handle) serves non-AVX512 targets.
 * Returns LIBXSMM_DNN_SUCCESS or an error code (unbound data, unsupported
 * datatype, unsupported fusion). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
int l_emu_amx = 0;
/* EMULATE_AMX (non-zero in the environment) forces the AMX-emulation code path */
const char *const l_env_emu_amx = getenv("EMULATE_AMX");
if ( 0 == l_env_emu_amx ) {
} else {
l_emu_amx = atoi(l_env_emu_amx);
}
/* check if all required tensors are bound */
if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) {
if (handle->grad_input == 0 || handle->grad_output == 0 ||
handle->reg_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
} else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) {
if (handle->reg_input == 0 || handle->grad_output == 0 ||
handle->grad_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
} else {
/* fused BWD+UPD needs the tensors of both passes */
if (handle->grad_input == 0 || handle->grad_output == 0 ||
handle->reg_input == 0 || handle->grad_filter == 0 ||
handle->reg_filter == 0 || handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* fused bias requires the gradient-bias tensor to be bound */
if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) != 0) && ( handle->grad_bias == 0 ) ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* fused ReLU requires the forward pass' ReLU mask to be bound */
if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) != 0) && ( handle->relumask == 0 ) ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid);
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
/* BF16 routing: emulation below CPX, native BF16 on CPX, AMX (or its emulation) on SPR+ */
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16( handle, kind, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
if ( l_emu_amx == 0 ) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid);
} else {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu( handle, kind, start_thread, tid);
}
}
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
/* no native BF16 compiled in: emulate below SPR, AMX (or its emulation) on SPR+ */
else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid);
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) {
if ( l_emu_amx == 0 ) {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid);
} else {
status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_amx_emu( handle, kind, start_thread, tid);
}
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* generic (non-AVX512) fallback: the AMX-emulation flag is irrelevant here */
LIBXSMM_UNUSED( l_emu_amx );
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
/* batch-reduce strided GEMM kernels prepared at handle-creation time;
 * the included template refers to these locals by name */
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs;
/* select the template instantiation matching the requested fusion */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* Backward/weight-update pass for the NHWC layout: not implemented.
 * All arguments are deliberately unused; the NOT_IMPLEMENTED error code is
 * returned unconditionally. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
LIBXSMM_UNUSED( handle );
LIBXSMM_UNUSED( kind );
LIBXSMM_UNUSED( start_thread );
LIBXSMM_UNUSED( tid );
return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FULLYCONNECTED_BACKWARD_WEIGHT_UPDATE_H
#define LIBXSMM_DNN_FULLYCONNECTED_BACKWARD_WEIGHT_UPDATE_H
#include <libxsmm_dnn_fullyconnected.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
#endif /* LIBXSMM_DNN_FULLYCONNECTED_BACKWARD_WEIGHT_UPDATE_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fullyconnected_forward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
/* Forward pass of the fully-connected layer in the custom layout for
 * F32 in / F32 out. Dispatches a small GEMM kernel (C = 1*A*B + 0*C) and
 * instantiates the generic forward template; only the unfused variant is
 * supported here. Compiled for AVX512 targets; otherwise returns
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* the included template refers to these typedefs and locals by name */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
typedef libxsmm_smmfunction gemm_function;
element_input_type alpha = (element_input_type)1;
element_input_type beta = (element_input_type)0;
/* leading dimensions: filter block (lda), input features C (ldb), output features K (ldc) */
libxsmm_blasint lda = (libxsmm_blasint)handle->ofmblock;
libxsmm_blasint ldb = (libxsmm_blasint)handle->desc.C;
libxsmm_blasint ldc = (libxsmm_blasint)handle->desc.K;
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->desc.N, handle->desc.C, &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c"
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Forward pass of the fully-connected layer in the custom layout for
 * BF16 in / F32 out. Dispatches a small F32 GEMM kernel (C = 1*A*B + 0*C)
 * and instantiates the generic forward template with BF16 up-conversion
 * enabled; only the unfused variant is supported. Compiled for AVX512
 * targets; otherwise returns LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* the included template refers to these typedefs and locals by name */
typedef libxsmm_bfloat16 element_input_type;
typedef float element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
typedef libxsmm_smmfunction gemm_function;
/* leading dimensions: filter block (lda), input features C (ldb), output features K (ldc) */
libxsmm_blasint lda = (libxsmm_blasint)handle->ofmblock;
libxsmm_blasint ldb = (libxsmm_blasint)handle->desc.C;
libxsmm_blasint ldc = (libxsmm_blasint)handle->desc.K;
/* alpha/beta are FP32 scalars for the F32 GEMM dispatch; previously they were
 * initialized via a cast through the integral BF16 storage type (copy/paste
 * from the F32 variant) -- harmless for the values 0 and 1, but the correct
 * type of the constant is float */
float alpha = 1.0f;
float beta = 0.0f;
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->desc.N, handle->desc.C, &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
# define LIBXSMM_DNN_FULLYCONNECTED_FWD_BF16_F32
# include "template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FULLYCONNECTED_FWD_BF16_F32
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Forward pass of the fully-connected layer in the blocked NCNC/KCCK layout
 * for F32 in / F32 out on AVX512. Uses the batch-reduce strided GEMM kernels
 * prepared on the handle and instantiates the generic forward template once
 * per supported fusion (none/bias/ReLU/sigmoid and the bias combinations).
 * Returns LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when not compiled for AVX512. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* the included template refers to these typedefs and kernel locals by name */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd.xgemm.smrs;
libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.smrs;
#define LIBXSMM_DNN_FC_FWD_USE_AVX512
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
#undef LIBXSMM_DNN_FC_FWD_USE_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Forward pass (NCNC/KCCK, BF16 in / BF16 out) for AVX512_CORE targets that
 * lack native BF16 instructions: BF16<->FP32 conversion is emulated via the
 * shared macro templates. Instantiates the BF16 forward template once per
 * supported fusion. Returns LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when not
 * compiled for AVX512_CORE. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
/* the included templates refer to these typedefs and kernel locals by name */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.bmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd3.xgemm.bmrs;
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* Forward pass (NCNC/KCCK, BF16 in / BF16 out) using native AVX512-BF16
 * (Cooper Lake / CPX) instructions: LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI makes
 * the shared BF16 macro templates emit the hardware conversion sequences.
 * Instantiates the BF16 forward template once per supported fusion. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
/* the included templates refer to these typedefs and kernel locals by name */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.bmrs;
libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd3.xgemm.bmrs;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Built without AVX512-BF16 (CPX) support: serve the native-BF16 entry point
 * by delegating to the emulated kernel. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
const libxsmm_dnn_err_t result = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu( handle, start_thread, tid );
return result;
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* Forward pass (NCNC/KCCK, BF16 in / BF16 out) for AMX-capable targets,
 * compiled with native AVX512-BF16 (CPX) support. Two major paths:
 * - handle->compressed_A == 1: filters are stored compressed; fused
 *   decompress+GEMM kernels (gemm_fwd9..16) drive the sparse-A AMX template;
 * - otherwise: the dense AMX template with plain batch-reduce kernels.
 * Within each path the template is instantiated once per supported fusion,
 * with the fused-eltwise kernel variant selected to match. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
/* the included templates refer to these typedefs and kernel locals by name */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_zerobeta = handle->gemm_fwd3.xgemm.bmrs;
/* kernel that programs the AMX tile configuration before the GEMM loop */
libxsmm_bsmmfunction tile_config_kernel = handle->fwd_config_kernel;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if (handle->compressed_A == 1) {
libxsmm_bsmmfunction_reducebatch_strd_meltwfused batchreduce_kernel_decompress = handle->gemm_fwd9.xgemm.bsmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_decompress = handle->gemm_fwd11.xgemm.bmrs_meltwfused;
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd12.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd13.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd14.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd15.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd16.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
} else {
/* dense (uncompressed) filter path */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Built without AVX512-BF16 (CPX) support: the AMX entry point is served by
 * delegating to the emulated AMX kernel. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
const libxsmm_dnn_err_t result = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu( handle, start_thread, tid );
return result;
}
#endif
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel = handle->gemm_fwd.xgemm.bsmrs;
libxsmm_bmmfunction_reducebatch_strd bf16_batchreduce_kernel_zerobeta = handle->gemm_fwd3.xgemm.bmrs;
libxsmm_bsmmfunction tile_config_kernel = handle->fwd_config_kernel;
/* some portable macrros fof BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if (handle->compressed_A == 1) {
libxsmm_bsmmfunction_reducebatch_strd_meltwfused batchreduce_kernel_decompress = handle->gemm_fwd9.xgemm.bsmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_decompress = handle->gemm_fwd11.xgemm.bmrs_meltwfused;
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd12.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd13.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd14.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd15.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused;
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise_decompress = handle->gemm_fwd16.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_sparse_A_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
} else {
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd4.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd5.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd6.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd7.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
libxsmm_bmmfunction_reducebatch_strd_meltwfused bf16_batchreduce_kernel_zerobeta_fused_eltwise = handle->gemm_fwd8.xgemm.bmrs_meltwfused;
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
} else {
status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
}
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Forward-pass driver for the fully-connected layer in "custom" tensor format:
 * validates that the required tensors are bound, then dispatches to an
 * architecture/datatype-specific implementation on AVX512 targets, or falls
 * back to a generic FP32 small-GEMM code template otherwise.
 * @param handle layer handle (provides bound tensors, datatypes, target arch)
 * @param start_thread first thread id of the team executing this call
 * @param tid this thread's id
 * @return LIBXSMM_DNN_SUCCESS or an error code */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if all required tensors are bound */
  if (handle->reg_input == 0 || handle->reg_output == 0 ||
      handle->reg_filter == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fullyconnected_st_fwd_custom_f32_f32( handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    /* BF16 input with FP32 output requires at least an AVX512-CORE target */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE ) {
      status = libxsmm_dnn_fullyconnected_st_fwd_custom_bf16_f32( handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic fallback: FP32 only, implemented via an included code template */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction gemm_function;
      /* leading dimensions of the A/B/C operands for the dispatched small GEMM */
      libxsmm_blasint lda = (libxsmm_blasint)handle->ofmblock;
      libxsmm_blasint ldb = (libxsmm_blasint)handle->desc.C;
      libxsmm_blasint ldc = (libxsmm_blasint)handle->desc.K;
      element_input_type beta = (element_input_type)0;
      element_input_type alpha = (element_input_type)1;
      if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
        gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ofmblock, handle->desc.N, handle->desc.C, &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_fwd_custom_generic.tpl.c"
      } else {
        /* fused bias/activation is not supported on this generic path */
        status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Forward-pass driver for the fully-connected layer in NCNC (activations) /
 * KCCK (filter) blocked format: validates the bound tensors (including the
 * tensors required by the requested fused operations), then dispatches to an
 * architecture/datatype-specific implementation on AVX512 targets, or to the
 * generic FP32 batch-reduce GEMM templates otherwise.
 * @param handle layer handle (tensors, datatypes, fused ops, target arch)
 * @param start_thread first thread id of the team executing this call
 * @param tid this thread's id
 * @return LIBXSMM_DNN_SUCCESS or an error code */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  int l_emu_amx = 0;
  /* the EMULATE_AMX environment variable (parsed as an integer) selects the
   * AMX-emulation kernels instead of native AMX on SPR-class targets */
  const char *const l_env_emu_amx = getenv("EMULATE_AMX");
  if ( 0 == l_env_emu_amx ) {
  } else {
    l_emu_amx = atoi(l_env_emu_amx);
  }
  /* check if all required tensors are bound */
  if (handle->reg_input == 0 || handle->reg_output == 0 ||
      handle->reg_filter == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* fused bias requires a bound bias tensor */
  if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) != 0) && ( handle->reg_bias == 0 ) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* fused ReLU requires a bound ReLU-mask tensor */
  if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) != 0) && ( handle->relumask == 0 ) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    /* BF16/BF16: emulated kernels on CORE..CPX, native BF16 on CPX..SPR,
     * AMX (or AMX emulation, see EMULATE_AMX above) on SPR and newer */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      if ( l_emu_amx == 0 ) {
        status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid);
      } else {
        status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu( handle, start_thread, tid);
      }
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    /* without native BF16 intrinsics: emulated kernels up to SPR, AMX from SPR on */
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_emu( handle, start_thread, tid);
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      if ( l_emu_amx == 0 ) {
        status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid);
      } else {
        status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_bf16_bf16_amx_emu( handle, start_thread, tid);
      }
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    LIBXSMM_UNUSED( l_emu_amx );
    /* generic fallback: FP32 batch-reduce GEMM code template, specialized per
     * requested fusion via the LIBXSMM_DNN_FC_FWD_FUSE_* macros */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_beta = handle->gemm_fwd.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_zerobeta = handle->gemm_fwd2.xgemm.smrs;
      if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_NONE
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_NONE
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_FWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
      } else {
        status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* NHWC forward pass for the fully-connected layer: not implemented.
 * All arguments are intentionally ignored; callers always receive
 * LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_nhwc(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid)
{
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FULLYCONNECTED_FORWARD_H
#define LIBXSMM_DNN_FULLYCONNECTED_FORWARD_H
#include <libxsmm_dnn_fullyconnected.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_custom(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_fwd_nhwc(libxsmm_dnn_fullyconnected* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_FULLYCONNECTED_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fusedbatchnorm_backward.h"
#include "libxsmm_dnn_fusedbatchnorm_forward.h"
#include "libxsmm_main.h"
/* Allocates and initializes a fused-batchnorm handle for the given descriptor.
 * Validates the partial/full minibatch combination and the in/out datatype
 * pair (FP32->FP32 or BF16->BF16), derives the feature-map blocking, creates
 * the thread barrier, and sizes the scratch needed for batch statistics.
 * On failure, NULL is returned and *status carries the error code. */
LIBXSMM_API libxsmm_dnn_fusedbatchnorm* libxsmm_dnn_create_fusedbatchnorm(libxsmm_dnn_fusedbatchnorm_desc fusedbatchnorm_desc, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_fusedbatchnorm* result = 0;
  int lp_block;
  /* init libxsmm */
  LIBXSMM_INIT
  /* a partial minibatch must not exceed the full minibatch */
  if ( fusedbatchnorm_desc.partN > fusedbatchnorm_desc.fullN ) {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    return result;
  }
  /* a真 partial minibatch (partN != fullN) is only supported together with
   * non-reduced batch statistics or batch scaling */
  if ( (fusedbatchnorm_desc.partN != fusedbatchnorm_desc.fullN) &&
       ((fusedbatchnorm_desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) == 0) &&
       ((fusedbatchnorm_desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) == 0) ) {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    return result;
  }
  /* only the FP32->FP32 and BF16->BF16 datatype pairs are supported */
  if ( ((fusedbatchnorm_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) || (fusedbatchnorm_desc.datatype_out != LIBXSMM_DNN_DATATYPE_BF16)) &&
       ((fusedbatchnorm_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32)  || (fusedbatchnorm_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32)) ) {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    return result;
  }
  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  result = (libxsmm_dnn_fusedbatchnorm*)calloc(1, sizeof(libxsmm_dnn_fusedbatchnorm));
  if (0 == result) {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    return result;
  }
  *status = LIBXSMM_DNN_SUCCESS;
  /* keep the descriptor around */
  result->desc = fusedbatchnorm_desc;
  /* we need to compute the memory layout given the */
  *status = libxsmm_dnn_get_feature_map_blocks( result->desc.C, result->desc.C,
              &(result->ifmblock), &(result->ofmblock), &lp_block,
              result->desc.datatype_in, result->desc.datatype_out );
  /* compute the outer blocks */
  result->blocksifm = result->desc.C / result->ifmblock;
  result->blocksofm = result->desc.C / result->ofmblock;
  /* create barrier */
  result->barrier = libxsmm_barrier_create(result->desc.threads, 1);
  /* calculate scratch size for batchstats */
  result->scratch_size = (sizeof(float) * 2 * result->desc.C * result->desc.partN);
  return result;
}
/* Releases a fused-batchnorm handle: destroys the barrier (if any) and frees
 * the handle structure itself. Returns LIBXSMM_DNN_ERR_INVALID_HANDLE for a
 * NULL handle. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fusedbatchnorm(const libxsmm_dnn_fusedbatchnorm* handle) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* Deallocate barrier */
    if (0 != handle->barrier) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    /* deallocate handle structure; the cast removes constness */
    free((libxsmm_dnn_fusedbatchnorm*)handle);
  }
  return result;
}
/* Creates a tensor datalayout describing how the fused-batchnorm handle
 * expects the given tensor type to be laid out in memory.
 *
 * Activation tensors (input/output/input-add and their gradients) are either
 * 5-dim blocked (LIBXSMM custom format) or 4-dim NHWC; channel statistics
 * (beta/gamma/expectval/rcpstddev/variance) are 2-dim blocked or 1-dim over C;
 * the ReLU mask mirrors the output layout with an I8 element type.
 *
 * @param handle fused-batchnorm handle (must not be NULL)
 * @param type   tensor type the layout is requested for
 * @param status [out] LIBXSMM_DNN_SUCCESS or an error code
 * @return newly allocated layout (caller destroys via
 *         libxsmm_dnn_destroy_tensor_datalayout), or NULL on error
 *
 * Fixes vs. the previous revision:
 *  - NHWC ReLU-mask layout claimed num_dims == 6 although only 4 entries are
 *    allocated and initialized (out-of-bounds reads for consumers); now 4,
 *    consistent with the other NHWC branches.
 *  - a failed dim-array allocation in the NHWC activation branch was silently
 *    ignored (old "TODO: handle the error"); it now reports
 *    LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS like every sibling branch.
 *  - dim_type/dim_size are freed on error paths (free(NULL) is a no-op), so a
 *    partially successful allocation no longer leaks.
 *  - the F32 and BF16 blocked-activation branches were merged; they differed
 *    only in the datatype constant, and the NHWC branch already used the
 *    merged form (layout->datatype = handle->desc.datatype_in). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fusedbatchnorm_create_tensor_datalayout(const libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      layout->format = handle->desc.buffer_format;
      if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)     || (type == LIBXSMM_DNN_INPUT)  ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT)    || (type == LIBXSMM_DNN_GRADIENT_OUTPUT)    || (type == LIBXSMM_DNN_OUTPUT) ||
           (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
        /* activation tensors */
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32)  && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              /* fastest-varying dimension first: [c][W][H][C-blocks][N] */
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
                   (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.partN;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.partN;
              } else { /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout->dim_type); /* free(NULL) is a no-op */
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32)  && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
                   (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->desc.partN;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->desc.partN;
              } else { /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* previously an unhandled TODO: report the failed allocation */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)  || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  || (type == LIBXSMM_DNN_CHANNEL_BETA)  ||
                  (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_CHANNEL_GAMMA) ||
                  (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)     || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      || (type == LIBXSMM_DNN_CHANNEL_VARIANCE) ) {
        /* per-channel statistics/scaling vectors */
        layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->desc.datatype_stats;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 2;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->ifmblock;
              layout->dim_size[1] = handle->blocksifm;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->desc.datatype_stats;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 1;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->desc.C;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) {
        /* ReLU mask: same spatial extent as the output, one I8 element per value */
        layout->tensor_type = LIBXSMM_DNN_RELU_MASK;
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 5;
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
            layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->ofmblock;
            layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
            layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
            layout->dim_size[3] = handle->blocksofm;
            layout->dim_size[4] = handle->desc.partN;
          } else {
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 4; /* was 6: only 4 dims are allocated/initialized */
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
            layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->ofmblock*handle->blocksofm;
            layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
            layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
            layout->dim_size[3] = handle->desc.partN;
          } else {
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
/* Reports the scratch size the handle requires; the size is padded by 64
 * bytes so callers do not have to provide an aligned buffer themselves.
 * Returns 0 and sets *status to LIBXSMM_DNN_ERR_INVALID_HANDLE for a NULL
 * handle. */
LIBXSMM_API size_t libxsmm_dnn_fusedbatchnorm_get_scratch_size(const libxsmm_dnn_fusedbatchnorm* handle, libxsmm_dnn_err_t* status) {
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* 64 byte extra in case the user code does not care about alignment */
    result = handle->scratch_size + 64;
  }
  return result;
}
/* Binds an externally allocated scratch buffer to the handle. The stored
 * pointer is rounded up to the next 64-byte boundary, which is why
 * libxsmm_dnn_fusedbatchnorm_get_scratch_size() reports 64 extra bytes. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_bind_scratch(libxsmm_dnn_fusedbatchnorm* handle, const void* scratch) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == scratch) {
    result = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  } else if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    const uintptr_t address = (uintptr_t)scratch;
    const uintptr_t remainder = address % 64;
    /* align the internal scratch buffer if needed */
    handle->scratch = (void*)((0 == remainder) ? address : (address + (64 - remainder)));
  }
  return result;
}
/* Detaches the scratch buffer from the handle (the buffer itself is owned and
 * freed by the caller). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_release_scratch(libxsmm_dnn_fusedbatchnorm* handle) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    handle->scratch = 0;
  }
  return result;
}
/* Binds a user tensor to the handle slot that corresponds to the given tensor
 * type, after verifying that the tensor's datalayout matches the layout the
 * handle expects for that type. The handle does not take ownership. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_bind_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* reject tensor types that fused batchnorm does not know about */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) &&
       (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) &&
       (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }
  if (0 == handle || 0 == tensor) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
    return status;
  }
  {
    /* compare the tensor's layout against the layout expected for this type */
    libxsmm_dnn_tensor_datalayout* expected = libxsmm_dnn_fusedbatchnorm_create_tensor_datalayout(handle, type, &status);
    if ( 0 == libxsmm_dnn_compare_tensor_datalayout(expected, tensor->layout, &status) ) {
      /* map the tensor type onto the handle slot that stores it */
      libxsmm_dnn_tensor** slot = 0;
      if ( type == LIBXSMM_DNN_REGULAR_INPUT )               { slot = &handle->reg_input;   }
      else if ( type == LIBXSMM_DNN_GRADIENT_INPUT )         { slot = &handle->grad_input;  }
      else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT )         { slot = &handle->reg_output;  }
      else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT )        { slot = &handle->grad_output; }
      else if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD )      { slot = &handle->reg_add;     }
      else if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD )     { slot = &handle->grad_add;    }
      else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA )   { slot = &handle->reg_beta;    }
      else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA )  { slot = &handle->grad_beta;   }
      else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA )  { slot = &handle->reg_gamma;   }
      else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { slot = &handle->grad_gamma;  }
      else if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL )      { slot = &handle->expvalue;    }
      else if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV )      { slot = &handle->rcpstddev;   }
      else if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE )       { slot = &handle->variance;    }
      else if ( type == LIBXSMM_DNN_RELU_MASK )              { slot = &handle->relumask;    }
      if (0 != slot) { /* the type filter above guarantees a match */
        *slot = (libxsmm_dnn_tensor*)tensor;
      }
    } else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout( expected );
  }
  return status;
}
/* Returns the tensor currently bound to the handle for the given tensor type
 * (NULL if nothing is bound to that slot). Sets *status to an error code for
 * an unknown type or a NULL handle, and returns NULL in that case. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fusedbatchnorm_get_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  *status = LIBXSMM_DNN_SUCCESS;
  /* reject tensor types that fused batchnorm does not know about */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_REGULAR_INPUT_ADD) && (type != LIBXSMM_DNN_GRADIENT_INPUT_ADD) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_BETA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) &&
       (type != LIBXSMM_DNN_CHANNEL_EXPECTVAL) && (type != LIBXSMM_DNN_CHANNEL_RCPSTDDEV) &&
       (type != LIBXSMM_DNN_CHANNEL_VARIANCE) && (type != LIBXSMM_DNN_RELU_MASK) ) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return 0;
  }
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return 0;
  }
  /* return the slot that corresponds to the requested tensor type */
  if ( type == LIBXSMM_DNN_REGULAR_INPUT )          { return handle->reg_input;   }
  if ( type == LIBXSMM_DNN_GRADIENT_INPUT )         { return handle->grad_input;  }
  if ( type == LIBXSMM_DNN_REGULAR_OUTPUT )         { return handle->reg_output;  }
  if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT )        { return handle->grad_output; }
  if ( type == LIBXSMM_DNN_REGULAR_INPUT_ADD )      { return handle->reg_add;     }
  if ( type == LIBXSMM_DNN_GRADIENT_INPUT_ADD )     { return handle->grad_add;    }
  if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA )   { return handle->reg_beta;    }
  if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA )  { return handle->grad_beta;   }
  if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA )  { return handle->reg_gamma;   }
  if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA ) { return handle->grad_gamma;  }
  if ( type == LIBXSMM_DNN_CHANNEL_EXPECTVAL )      { return handle->expvalue;    }
  if ( type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV )      { return handle->rcpstddev;   }
  if ( type == LIBXSMM_DNN_CHANNEL_VARIANCE )       { return handle->variance;    }
  if ( type == LIBXSMM_DNN_RELU_MASK )              { return handle->relumask;    }
  return 0; /* cannot happen: filtered above */
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_release_tensor(libxsmm_dnn_fusedbatchnorm* handle, const libxsmm_dnn_tensor_type type) {
  /* Detach (release) the tensor bound to the given type from the handle by
   * clearing the corresponding slot. The tensor object itself is not freed;
   * ownership remains with the caller. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* validate the tensor type first: an unknown type is reported even for a NULL handle */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:         case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:        case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_INPUT_ADD:     case LIBXSMM_DNN_GRADIENT_INPUT_ADD:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:  case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
    case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA: case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
    case LIBXSMM_DNN_CHANNEL_EXPECTVAL:     case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
    case LIBXSMM_DNN_CHANNEL_VARIANCE:      case LIBXSMM_DNN_RELU_MASK:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (NULL == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* clear the slot that corresponds to the (already validated) type */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:          handle->reg_input = 0;   break;
    case LIBXSMM_DNN_GRADIENT_INPUT:         handle->grad_input = 0;  break;
    case LIBXSMM_DNN_REGULAR_OUTPUT:         handle->reg_output = 0;  break;
    case LIBXSMM_DNN_GRADIENT_OUTPUT:        handle->grad_output = 0; break;
    case LIBXSMM_DNN_REGULAR_INPUT_ADD:      handle->reg_add = 0;     break;
    case LIBXSMM_DNN_GRADIENT_INPUT_ADD:     handle->grad_add = 0;    break;
    case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:   handle->reg_beta = 0;    break;
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:  handle->grad_beta = 0;   break;
    case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:  handle->reg_gamma = 0;   break;
    case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA: handle->grad_gamma = 0;  break;
    case LIBXSMM_DNN_CHANNEL_EXPECTVAL:      handle->expvalue = 0;    break;
    case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:      handle->rcpstddev = 0;   break;
    case LIBXSMM_DNN_CHANNEL_VARIANCE:       handle->variance = 0;    break;
    case LIBXSMM_DNN_RELU_MASK:              handle->relumask = 0;    break;
    default: break; /* unreachable: type was validated above */
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_execute_st(libxsmm_dnn_fusedbatchnorm* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  /* Single-threaded-entry execution of the fused batch-norm layer: each of the
   * participating threads calls this with its own tid. Dispatches to the
   * forward or backward custom-format kernel; only the LIBXSMM tensor format
   * is implemented. Guard clauses replace the original nested switches; the
   * check order (handle, kind, buffer format) matches the original. */
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  if (LIBXSMM_DNN_COMPUTE_KIND_FWD != kind && LIBXSMM_DNN_COMPUTE_KIND_BWD != kind) {
    return LIBXSMM_DNN_ERR_INVALID_KIND;
  }
  if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM != handle->desc.buffer_format) {
    return LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
  }
  return (LIBXSMM_DNN_COMPUTE_KIND_FWD == kind)
    ? libxsmm_dnn_fusedbatchnorm_st_fwd_custom(handle, start_thread, tid)
    : libxsmm_dnn_fusedbatchnorm_st_bwd_custom(handle, start_thread, tid);
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  /* Reduce batch-norm statistics across several handles (e.g. when a batch is
   * partitioned over multiple layer instances). The buffer format of the first
   * handle decides the dispatch, as in the original; only the LIBXSMM custom
   * format is implemented. Guard clauses replace the original nested switches;
   * the check order (handles, kind, buffer format) matches the original. */
  if (0 == handles || 0 >= num_handles) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  if (LIBXSMM_DNN_COMPUTE_KIND_FWD != kind && LIBXSMM_DNN_COMPUTE_KIND_BWD != kind) {
    return LIBXSMM_DNN_ERR_INVALID_KIND;
  }
  if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM != handles[0]->desc.buffer_format) {
    return LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
  }
  return (LIBXSMM_DNN_COMPUTE_KIND_FWD == kind)
    ? libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(handles, num_handles, start_thread, tid)
    : libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom(handles, num_handles, start_thread, tid);
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fusedbatchnorm_backward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
/* AVX512 backward fused batch-norm kernel: F32 input/output, channel block of 16.
 * Selects a template specialization according to handle->desc.fuse_ops; the
 * LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_* macros toggle optional eltwise/ReLU code paths
 * inside the included template. Combined masks (ELTWISE_RELU, ELTWISE_RELU_WITH_MASK)
 * are tested before their single-op subsets so a fused combination is not
 * mis-dispatched to a partial kernel.
 * NOTE(review): the included template is named "f32_bf16" — presumably shared
 * between the F32 and BF16 flavors; confirm against the template sources. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  /* AVX512 intrinsics are unavailable at compile time: report unsupported arch */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward fused batch-norm kernel: F32 input/output, channel block of 32.
 * Same dispatch structure as the c16 variant; only the included template differs.
 * Combined fuse-op masks are tested before their single-op subsets. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  /* AVX512 intrinsics are unavailable at compile time: report unsupported arch */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward fused batch-norm kernel: F32 input/output, channel block of 64.
 * Same dispatch structure as the c16/c32 variants; only the included template differs.
 * Combined fuse-op masks are tested before their single-op subsets. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  /* AVX512 intrinsics are unavailable at compile time: report unsupported arch */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward fused batch-norm kernel: BF16 input/output (F32 statistics),
 * channel block of 16. Shares the F32/BF16 template with the F32 flavor; the
 * LIBXSMM_DNN_FUSEDBN_BWD_BF16 macro selects the BF16 code paths inside it and
 * is undefined again after all template expansions. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template; stats stay in F32 */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_BWD_BF16
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16
#else /* should not happen */
  /* AVX512 intrinsics are unavailable at compile time: report unsupported arch */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward fused batch-norm kernel: BF16 input/output (F32 statistics),
 * channel block of 32. Same structure as the bf16 c16 variant; only the
 * included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template; stats stay in F32 */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_BWD_BF16
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16
#else /* should not happen */
  /* AVX512 intrinsics are unavailable at compile time: report unsupported arch */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward fused batch-norm kernel: BF16 input/output (F32 statistics),
 * channel block of 64. Same structure as the bf16 c16/c32 variants; only the
 * included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template; stats stay in F32 */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_BWD_BF16
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16
#else /* should not happen */
  /* AVX512 intrinsics are unavailable at compile time: report unsupported arch */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Backward pass dispatcher for the custom-format fused batch-norm layer.
 * First validates that all tensors required by the configured fuse_ops are
 * bound, then dispatches by target architecture and channel block size
 * (ofmblock 16/32/64 on AVX512) and by input/output datatype (F32 or BF16);
 * anything else falls back to the generic C template. Returns a
 * libxsmm_dnn_err_t status. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if all required tensors are bound */
  if ( handle->reg_input == 0 || handle->reg_gamma == 0 ||
       handle->grad_input == 0 || handle->grad_output == 0 ||
       handle->grad_beta == 0 || handle->grad_gamma == 0 ||
       handle->expvalue == 0 || handle->rcpstddev == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* full batch-norm fusions additionally need the scratch buffer
   * (presumably for the per-image partial dgamma/dbeta reductions — confirm) */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0 ) {
    if ( handle->scratch == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
  /* eltwise fusion propagates a gradient through the residual add input */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) > 0 ) {
    if ( handle->grad_add == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
  /* plain ReLU backward re-derives the mask from the forward output */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) > 0 ) {
    if ( handle->reg_output == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
  /* masked ReLU backward needs the explicit mask recorded in the forward pass */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) > 0 ) {
    if ( handle->relumask == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
       (handle->ofmblock == 16) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c16( handle, start_thread, tid );
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
       (handle->ofmblock == 32) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c32( handle, start_thread, tid );
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
       (handle->ofmblock == 64) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_f32_f32_c64( handle, start_thread, tid );
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  /* generic (non-AVX512 or unusual block size) fallback path */
  {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* F32 in/out: expand the generic template per fused-op combination */
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_stats_type;
      if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
        status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
      } else {
        if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
             (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
        } else {
          status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
        }
      }
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      /* BF16 in/out (F32 statistics): same template with the BF16 switch defined */
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
      typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_BWD_BF16
      if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
        status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
      } else {
        if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
             (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) {
# define LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_BWD_ENABLE_RELU_WITH_MASK
        } else {
          status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
        }
      }
# undef LIBXSMM_DNN_FUSEDBN_BWD_BF16
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  /* NHWC backward pass is not implemented; silence unused-parameter warnings
   * and report the condition to the caller. */
  LIBXSMM_UNUSED( handle ); LIBXSMM_UNUSED( start_thread ); LIBXSMM_UNUSED( tid );
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/* Reduces the backward batchnorm gradient statistics (dgamma/dbeta) across a
 * group of fused-batchnorm handles that each processed a partition of the same
 * layer, then replicates the reduced sums back into every handle's
 * grad_gamma/grad_beta tensors. Work is parallelized over feature-map blocks;
 * all participating threads synchronize on handles[0]->barrier.
 * Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND if any handle is missing grad_beta,
 * grad_gamma, or scratch; LIBXSMM_DNN_SUCCESS otherwise. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
int l_count;
/* check if all required tensors are bound */
for ( l_count = 0; l_count < num_handles; ++l_count ) {
if ( handles[l_count]->grad_beta == 0 || handles[l_count]->grad_gamma == 0 || handles[l_count]->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
#if 0
/* check if we are on an AVX512 platform */
if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom_avx512( handles, num_handles, start_thread, tid );
} else
#endif
{
/* NOTE(review): all handles are assumed to share the geometry of handle 0
 * (partN, blocksifm, ifmblock) -- no per-handle check is performed here. */
const int nImg = handles[0]->desc.partN;
const int nBlocksFm = handles[0]->blocksifm;
const int nFmBlock = handles[0]->ifmblock;
/* computing first logical thread */
const int ltid = tid - start_thread;
/* number of tasks that could be run in parallel */
const int work2 = nBlocksFm;
/* compute chunk size */
const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2;
const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2;
int v = 0, fm;
/* scratch layout: dgamma partials start at scratch[0], dbeta partials follow
 * at offset nImg*nBlocksFm*nFmBlock floats */
LIBXSMM_VLA_DECL(2, float, dgamma0, (float*)handles[0]->grad_gamma->data, nFmBlock);
LIBXSMM_VLA_DECL(2, float, dbeta0, (float*)handles[0]->grad_beta->data, nFmBlock);
LIBXSMM_VLA_DECL(3, float, dgamma_img0, (float*)handles[0]->scratch, nImg, nFmBlock);
LIBXSMM_VLA_DECL(3, float, dbeta_img0, ((float*)handles[0]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
/* lazy barrier init */
libxsmm_barrier_init(handles[0]->barrier, ltid);
/* seed the accumulators with handle 0's partial sums.
 * NOTE(review): only image slot 0 of the scratch is read here -- presumably
 * the bwd kernel already reduced per-image partials into slot 0; confirm
 * against the bwd template before relying on this. */
for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock);
float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock);
float* dgamma_img0_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_img0, fm, 0, 0, nImg, nFmBlock);
float* dbeta_img0_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_img0, fm, 0, 0, nImg, nFmBlock);
LIBXSMM_PRAGMA_SIMD
for ( v=0; v < nFmBlock; v++ ) {
dgamma0_ptr[v] = dgamma_img0_ptr[v];
dbeta0_ptr[v] = dbeta_img0_ptr[v];
}
}
/* now we need to reduce the dgamma and dbeta */
/* accumulate the remaining handles' partials into handle 0's grad tensors */
for ( l_count = 1; l_count < num_handles; ++l_count ) {
LIBXSMM_VLA_DECL(3, float, dgamma_imgr, (float*)handles[l_count]->scratch, nImg, nFmBlock);
LIBXSMM_VLA_DECL(3, float, dbeta_imgr, ((float*)handles[l_count]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock);
float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock);
float* dgamma_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, dgamma_imgr, fm, 0, 0, nImg, nFmBlock);
float* dbeta_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, dbeta_imgr, fm, 0, 0, nImg, nFmBlock);
LIBXSMM_PRAGMA_SIMD
for ( v=0; v < nFmBlock; v++ ) {
dgamma0_ptr[v] += dgamma_imgr_ptr[v];
dbeta0_ptr[v] += dbeta_imgr_ptr[v];
}
}
}
/* broadcast the fully reduced sums from handle 0 into every other handle */
for ( l_count = 1; l_count < num_handles; ++l_count ) {
LIBXSMM_VLA_DECL(2, float, dgammar, (float*)handles[l_count]->grad_gamma->data, nFmBlock);
LIBXSMM_VLA_DECL(2, float, dbetar, (float*)handles[l_count]->grad_beta->data, nFmBlock);
for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock);
float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock);
float* dgammar_ptr = &LIBXSMM_VLA_ACCESS(2, dgammar, fm, 0, nFmBlock);
float* dbetar_ptr = &LIBXSMM_VLA_ACCESS(2, dbetar, fm, 0, nFmBlock);
LIBXSMM_PRAGMA_SIMD
for ( v=0; v < nFmBlock; v++ ) {
dgammar_ptr[v] = dgamma0_ptr[v];
dbetar_ptr[v] = dbeta0_ptr[v];
}
}
}
libxsmm_barrier_wait(handles[0]->barrier, ltid);
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FUSEDBATCHNORM_BACKWARD_H
#define LIBXSMM_DNN_FUSEDBATCHNORM_BACKWARD_H
#include <libxsmm_dnn_fusedbatchnorm.h>
/* Backward fused-batchnorm pass for the custom (blocked) tensor layout. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
/* Backward pass for NHWC layout; currently a stub returning LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_bwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
/* Cross-handle reduction of dgamma/dbeta gradient statistics for split batchnorm handles. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid);
#endif /* LIBXSMM_DNN_FUSEDBATCHNORM_BACKWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fusedbatchnorm_forward.h"
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/* AVX-512 forward kernels specialized by datatype (f32/bf16) and channel block
 * size (16/32/64); defined below and dispatched from
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom(). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
/* AVX-512 forward fused-batchnorm for FP32 input/output with a channel block of 16.
 * The loop nest lives in the included template; the #define/#undef pairs around
 * each include select which fused ops (eltwise add, ReLU, ReLU-with-mask) the
 * template instantiates. Only the BN_ELTWISE_RELU fusion order is supported.
 * Falls back to LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when AVX-512 intrinsics are
 * unavailable at compile time. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element types consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* exact matches (pure BN variants) first, then bitwise subset tests ordered
 * from the most specific fusion to the least so ELTWISE_RELU wins over ELTWISE */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 forward fused-batchnorm for FP32 input/output with a channel block of 32.
 * Identical structure to the c16 variant: #define/#undef pairs select the fused
 * ops the included template instantiates; only BN_ELTWISE_RELU order is supported. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element types consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 forward fused-batchnorm for FP32 input/output with a channel block of 64.
 * Identical structure to the c16/c32 variants; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element types consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 forward fused-batchnorm for BF16 input/output with a channel block of 16.
 * Reuses the same f32/bf16 template as the FP32 variant; defining
 * LIBXSMM_DNN_FUSEDBN_FWD_BF16 for the duration of the includes switches the
 * template to its bf16 code path. Statistics stay in FP32. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element types consumed by the included template; stats remain FP32 */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 forward fused-batchnorm for BF16 input/output with a channel block of 32.
 * Same scheme as the bf16 c16 variant: LIBXSMM_DNN_FUSEDBN_FWD_BF16 switches the
 * shared f32/bf16 template to its bf16 path; statistics stay in FP32. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element types consumed by the included template; stats remain FP32 */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 forward fused-batchnorm for BF16 input/output with a channel block of 64.
 * Same scheme as the bf16 c16/c32 variants: LIBXSMM_DNN_FUSEDBN_FWD_BF16 switches
 * the shared f32/bf16 template to its bf16 path; statistics stay in FP32. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element types consumed by the included template; stats remain FP32 */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Dispatcher for the forward fused-batchnorm pass in the custom (blocked)
 * layout: validates that the tensors required by the requested fusion are
 * bound, then routes to an AVX-512 kernel specialized for the channel block
 * (16/32/64) and datatype (f32 or bf16), falling back to the generic C
 * template otherwise. Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND,
 * LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE, or the status of the selected kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if all required tensors are bound */
if ( handle->reg_input == 0 || handle->reg_output == 0 ||
handle->reg_beta == 0 || handle->reg_gamma == 0 ||
handle->expvalue == 0 || handle->rcpstddev == 0 || handle->variance == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* full BN additionally needs scratch for the statistics computation */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0 ) {
if ( handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* fused eltwise add needs the second input tensor */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) > 0 ) {
if ( handle->reg_add == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* masked ReLU needs the mask output tensor */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) > 0 ) {
if ( handle->relumask == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 16) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 32) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 64) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* generic fallback: same template-selection scheme as the AVX-512 kernels,
 * but using the portable C template */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
} else {
/* most specific fusion tested first; ordering of this else-if chain matters */
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
(handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* Forward fused-batchnorm for the NHWC tensor layout.
 * No NHWC implementation exists yet: the parameters are only referenced to
 * silence unused-parameter warnings and the stub reports NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
LIBXSMM_UNUSED( handle );
LIBXSMM_UNUSED( start_thread );
LIBXSMM_UNUSED( tid );
return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/** Reduces forward batch statistics (sum and sum of squares) across multiple
 *  fused-batchnorm handles (split-batch execution), finalizes mean,
 *  reciprocal stddev and variance into handle 0's stats tensors, and then
 *  broadcasts the finalized statistics to every other handle.
 *  Work is partitioned over feature-map blocks across the handle's threads.
 *  @param handles      array of num_handles bound fused-batchnorm handles
 *  @param num_handles  number of entries in handles
 *  @param start_thread first logical thread id of this team
 *  @param tid          this thread's id
 *  @return LIBXSMM_DNN_SUCCESS, or LIBXSMM_DNN_ERR_DATA_NOT_BOUND if any
 *          handle lacks its expvalue/rcpstddev/variance tensor or scratch. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  int l_count;
  /* check if all required tensors are bound */
  for ( l_count = 0; l_count < num_handles; ++l_count ) {
    if ( handles[l_count]->expvalue == 0 || handles[l_count]->rcpstddev == 0 || handles[l_count]->variance == 0 || handles[l_count]->scratch == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
#if 0
  /* check if we are on an AVX512 platform */
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom_avx512( handles, num_handles, start_thread, tid );
  } else
#endif
  {
    /* problem sizes are taken from handle 0; all handles are assumed to share
     * the same geometry (partN, blocking) -- NOTE(review): not verified here */
    const int nImg = handles[0]->desc.partN;
    const int nBlocksFm = handles[0]->blocksifm;
    const int nFmBlock = handles[0]->ifmblock;
    /* computing first logical thread */
    const int ltid = tid - start_thread;
    /* number of tasks that could be run in parallel (one task per fm block) */
    const int work2 = nBlocksFm;
    /* compute chunk size */
    const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1);
    /* compute thr_begin and thr_end */
    const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2;
    const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2;
    int v = 0, fm;
    /* epsilon added to the variance before the rsqrt to avoid division by zero */
    const float sqrt_eps = 1e-7f;
    /* normalization count: full (global) batch times spatial extent */
    const float nhw = (float)(handles[0]->desc.fullN * handles[0]->desc.H * handles[0]->desc.W);
    const float recp_nhw = 1.0f/nhw;
    /* finalized statistics live on handle 0; scratch holds the partial sums,
     * with sum^2 stored behind the sums at offset nImg*nBlocksFm*nFmBlock */
    LIBXSMM_VLA_DECL(2, float, bmean0, (float*)handles[0]->expvalue->data, nFmBlock);
    LIBXSMM_VLA_DECL(2, float, brstd0, (float*)handles[0]->rcpstddev->data, nFmBlock);
    LIBXSMM_VLA_DECL(2, float, variance0, (float*)handles[0]->variance->data, nFmBlock);
    LIBXSMM_VLA_DECL(3, float, sum_img0, (float*)handles[0]->scratch, nImg, nFmBlock);
    LIBXSMM_VLA_DECL(3, float, sumsq_img0, ((float*)handles[0]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
    /* lazy barrier init */
    libxsmm_barrier_init(handles[0]->barrier, ltid);
    /* step 1: accumulate every other handle's partial sums into handle 0.
     * NOTE(review): only image slot 0 of each fm block is read/accumulated --
     * assumes the per-image sums were already reduced into image index 0 by
     * the preceding forward pass; confirm against the fwd template. */
    for ( l_count = 1; l_count < num_handles; ++l_count ) {
      LIBXSMM_VLA_DECL(3, float, sum_imgr, (float*)handles[l_count]->scratch, nImg, nFmBlock);
      LIBXSMM_VLA_DECL(3, float, sumsq_imgr, ((float*)handles[l_count]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
      for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
        float* sum_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img0, fm, 0, 0, nImg, nFmBlock);
        float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock);
        float* sum_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sum_imgr, fm, 0, 0, nImg, nFmBlock);
        float* sumsq_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_imgr, fm, 0, 0, nImg, nFmBlock);
        LIBXSMM_PRAGMA_SIMD
        for ( v=0; v < nFmBlock; v++ ) {
          sum_img0_ptr[v] += sum_imgr_ptr[v];
          sumsq_img0_ptr[v] += sumsq_imgr_ptr[v];
        }
      }
    }
    /* step 2: finalize mean, reciprocal stddev and variance on handle 0
     * using E[x] = sum/nhw and Var[x] = E[x^2] - E[x]^2 */
    for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
      float* bmean0_ptr = &LIBXSMM_VLA_ACCESS(2, bmean0, fm, 0, nFmBlock);
      float* brstd0_ptr = &LIBXSMM_VLA_ACCESS(2, brstd0, fm, 0, nFmBlock);
      float* tvar0_ptr = &LIBXSMM_VLA_ACCESS(2, variance0, fm, 0, nFmBlock);
      float* sum_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img0, fm, 0, 0, nImg, nFmBlock);
      float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock);
      LIBXSMM_PRAGMA_SIMD
      for ( v=0; v < nFmBlock; v++ ) {
        const float tbmean = (recp_nhw * sum_img0_ptr[v]);
        const float tbmeansq = tbmean * tbmean;
        const float tsqbmean = recp_nhw * sumsq_img0_ptr[v];
        const float tvar = tsqbmean - tbmeansq;
        /* rsqrt computed in double for accuracy, then narrowed to float */
        const float tbrstd = (float)(1.0/sqrt((double)tvar + sqrt_eps));
        bmean0_ptr[v] = tbmean;
        brstd0_ptr[v] = tbrstd;
        tvar0_ptr[v] = tvar;
      }
    }
    /* step 3: broadcast the finalized statistics to all remaining handles
     * (same fm chunk as steps 1-2, so no cross-thread dependency) */
    for ( l_count = 1; l_count < num_handles; ++l_count ) {
      LIBXSMM_VLA_DECL(2, float, bmeanr, (float*)handles[l_count]->expvalue->data, nFmBlock);
      LIBXSMM_VLA_DECL(2, float, brstdr, (float*)handles[l_count]->rcpstddev->data, nFmBlock);
      LIBXSMM_VLA_DECL(2, float, variancer, (float*)handles[l_count]->variance->data, nFmBlock);
      for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
        float* bmean0_ptr = &LIBXSMM_VLA_ACCESS(2, bmean0, fm, 0, nFmBlock);
        float* brstd0_ptr = &LIBXSMM_VLA_ACCESS(2, brstd0, fm, 0, nFmBlock);
        float* tvar0_ptr = &LIBXSMM_VLA_ACCESS(2, variance0, fm, 0, nFmBlock);
        float* bmeanr_ptr = &LIBXSMM_VLA_ACCESS(2, bmeanr, fm, 0, nFmBlock);
        float* brstdr_ptr = &LIBXSMM_VLA_ACCESS(2, brstdr, fm, 0, nFmBlock);
        float* tvarr_ptr = &LIBXSMM_VLA_ACCESS(2, variancer, fm, 0, nFmBlock);
        LIBXSMM_PRAGMA_SIMD
        for ( v=0; v < nFmBlock; v++ ) {
          bmeanr_ptr[v] = bmean0_ptr[v];
          brstdr_ptr[v] = brstd0_ptr[v];
          tvarr_ptr[v] = tvar0_ptr[v];
        }
      }
    }
    libxsmm_barrier_wait(handles[0]->barrier, ltid);
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FUSEDBATCHNORM_FORWARD_H
#define LIBXSMM_DNN_FUSEDBATCHNORM_FORWARD_H
#include <libxsmm_dnn_fusedbatchnorm.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid);
#endif /* LIBXSMM_DNN_FUSEDBATCHNORM_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fusedgroupnorm_backward.h"
#include "libxsmm_dnn_fusedgroupnorm_forward.h"
#include "libxsmm_main.h"
/** Creates a fused group-norm handle for the given descriptor.
 *  Supported datatype combinations are F32->F32 and BF16->BF16.
 *  @param fusedgroupnorm_desc descriptor (copied into the handle)
 *  @param status out: LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_CREATE_HANDLE or
 *                LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
 *  @return the new handle, or NULL on failure. Callers must release it with
 *          libxsmm_dnn_destroy_fusedgroupnorm. */
LIBXSMM_API libxsmm_dnn_fusedgroupnorm* libxsmm_dnn_create_fusedgroupnorm(libxsmm_dnn_fusedgroupnorm_desc fusedgroupnorm_desc, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_fusedgroupnorm* handle = 0;
  int lpb; /* low-precision blocking factor; produced by the helper below but not used here */
  /* init libxsmm */
  LIBXSMM_INIT
  if ( ((fusedgroupnorm_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fusedgroupnorm_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
       ((fusedgroupnorm_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (fusedgroupnorm_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    handle = (libxsmm_dnn_fusedgroupnorm*)calloc(1, sizeof(libxsmm_dnn_fusedgroupnorm));
    if (0 != handle) {
      *status = LIBXSMM_DNN_SUCCESS;
      /* let's make the description persistent */
      handle->desc = fusedgroupnorm_desc;
      /* we need to compute the memory layout given the
       * NOTE(review): the returned status is stored but not checked before the
       * divisions below -- a failing helper would leave ifmblock/ofmblock in an
       * unspecified state; confirm upstream guarantees */
      *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C,
                                                    &(handle->ifmblock), &(handle->ofmblock), &lpb,
                                                    handle->desc.datatype_in, handle->desc.datatype_out );
      /* compute the outer blocks (number of channel blocks) */
      handle->blocksifm = handle->desc.C / handle->ifmblock;
      handle->blocksofm = handle->desc.C / handle->ofmblock;
      /* create barrier */
      handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
      /* calculate scratch size for batchstats: two float buffers (sum and
       * sum-of-squares), each covering per-channel (C*N) plus per-group (G*N)
       * accumulators */
      handle->scratch_size = (sizeof(float) * 2 * ((handle->desc.C * handle->desc.N) + (handle->desc.G * handle->desc.N)));
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  return handle;
}
/** Destroys a fused group-norm handle created by
 *  libxsmm_dnn_create_fusedgroupnorm, releasing its barrier and the handle
 *  storage itself.
 *  @param handle handle to destroy (may be NULL)
 *  @return LIBXSMM_DNN_SUCCESS, or LIBXSMM_DNN_ERR_INVALID_HANDLE for NULL. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fusedgroupnorm(const libxsmm_dnn_fusedgroupnorm* handle) {
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* release the barrier if one was created */
  if (0 != handle->barrier) {
    libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
  }
  /* handle was calloc'ed; the cast drops constness for free */
  free((libxsmm_dnn_fusedgroupnorm*)handle);
  return LIBXSMM_DNN_SUCCESS;
}
/** Creates the tensor datalayout (dimension types/sizes) that this group-norm
 *  handle expects for a given tensor type.
 *  Activation tensors get a 5-dim blocked layout (LIBXSMM format) or a 4-dim
 *  NHWC layout; beta/gamma are per-channel scalars; expectval/rcpstddev/
 *  variance are per-(group,batch) scalars; the relu mask mirrors the output
 *  activation shape with I8 elements.
 *  Fixes vs. previous revision: NHWC relu-mask num_dims corrected from 6 to 4
 *  (only 4 dims are allocated/initialized); failed dim-array allocations in
 *  the NHWC activation branch now report LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS
 *  instead of returning a half-built layout; partial dim_type/dim_size
 *  allocations are freed on every error path (free(NULL) is a no-op).
 *  @param handle group-norm handle (must be non-NULL)
 *  @param type   tensor type to describe
 *  @param status out: LIBXSMM_DNN_SUCCESS or an error code
 *  @return newly allocated layout (caller destroys it), or NULL on error. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fusedgroupnorm_create_tensor_datalayout(const libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      layout->format = handle->desc.buffer_format;
      if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ||
           (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              /* blocked layout: [N][C-blocks][H][W][C-block], fastest dim first */
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
                   (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else { /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* free(NULL) is a no-op; avoid leaking a partial allocation */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
                   (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* free(NULL) is a no-op; avoid leaking a partial allocation */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            layout->datatype = handle->desc.datatype_in;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              /* NHWC layout: [N][H][W][C], fastest dim first */
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
                   (type == LIBXSMM_DNN_REGULAR_INPUT_ADD) || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* BUGFIX: previously fell through ("TODO: handle the error") and
               * returned a layout with NULL/partial dim arrays and SUCCESS */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA) || (type == LIBXSMM_DNN_CHANNEL_BETA) ||
                  (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_CHANNEL_GAMMA) ) {
        layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->desc.datatype_stats;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              /* per-channel scalar, blocked: [C-blocks][C-block] */
              layout->num_dims = 2;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->ifmblock;
              layout->dim_size[1] = handle->blocksifm;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->desc.datatype_stats;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              /* per-channel scalar, flat: [C] */
              layout->num_dims = 1;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->desc.C;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL) || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV) || (type == LIBXSMM_DNN_CHANNEL_VARIANCE) ) {
        layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
        if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) || ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) ) {
          if ( handle->desc.datatype_stats == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->desc.datatype_stats;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              /* group-norm statistics are per (group, image): [N][G] */
              layout->num_dims = 2;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_G;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              layout->dim_size[0] = handle->desc.G;
              layout->dim_size[1] = handle->desc.N;
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) {
        layout->tensor_type = LIBXSMM_DNN_RELU_MASK;
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            /* mirrors the blocked output activation shape with I8 elements */
            layout->num_dims = 5;
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
            layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->ofmblock;
            layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
            layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
            layout->dim_size[3] = handle->blocksofm;
            layout->dim_size[4] = handle->desc.N;
          } else {
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            /* BUGFIX: was 6, but only 4 dimensions are allocated and set */
            layout->num_dims = 4;
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
            layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->ofmblock*handle->blocksofm;
            layout->dim_size[1] = (handle->desc.W/handle->desc.v) + (2*handle->desc.pad_w_out);
            layout->dim_size[2] = (handle->desc.H/handle->desc.u) + (2*handle->desc.pad_h_out);
            layout->dim_size[3] = handle->desc.N;
          } else {
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
/** Returns the scratch size (bytes) the caller must provide for this handle.
 *  @param handle group-norm handle (may be NULL)
 *  @param status out: LIBXSMM_DNN_SUCCESS or LIBXSMM_DNN_ERR_INVALID_HANDLE
 *  @return required scratch size, or 0 for a NULL handle. */
LIBXSMM_API size_t libxsmm_dnn_fusedgroupnorm_get_scratch_size(const libxsmm_dnn_fusedgroupnorm* handle, libxsmm_dnn_err_t* status) {
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* 64 byte extra in case the user code does not care about alignment */
    result = handle->scratch_size + 64;
  }
  return result;
}
/** Binds user-provided scratch memory to the handle, aligning it up to the
 *  next 64-byte boundary (the extra 64 bytes from get_scratch_size cover the
 *  worst-case adjustment).
 *  @param handle  group-norm handle
 *  @param scratch scratch buffer (must be non-NULL)
 *  @return LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED for a
 *          NULL buffer, or LIBXSMM_DNN_ERR_INVALID_HANDLE. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_bind_scratch(libxsmm_dnn_fusedgroupnorm* handle, const void* scratch) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == scratch) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  } else if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    const uintptr_t address = (uintptr_t)scratch;
    const uintptr_t misalignment = address % 64;
    /* round up to the next 64-byte boundary when needed */
    handle->scratch = (void*)((0 == misalignment) ? address : (address + (64 - misalignment)));
  }
  return status;
}
/** Unbinds the scratch buffer from the handle (does not free the memory; the
 *  buffer is owned by the caller).
 *  @param handle group-norm handle
 *  @return LIBXSMM_DNN_SUCCESS, or LIBXSMM_DNN_ERR_INVALID_HANDLE for NULL. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_release_scratch(libxsmm_dnn_fusedgroupnorm* handle) {
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  handle->scratch = 0;
  return LIBXSMM_DNN_SUCCESS;
}
/** Binds a tensor to the handle after validating its type and comparing its
 *  layout against the layout the handle expects for that type.
 *  @param handle group-norm handle
 *  @param tensor tensor to bind (not copied; caller keeps ownership)
 *  @param type   tensor type slot to bind into
 *  @return LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE,
 *          LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR, or
 *          LIBXSMM_DNN_ERR_MISMATCH_TENSOR on layout mismatch. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_bind_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* positive form of the supported-type check */
  const int type_supported =
    (type == LIBXSMM_DNN_REGULAR_INPUT)         || (type == LIBXSMM_DNN_GRADIENT_INPUT)         ||
    (type == LIBXSMM_DNN_REGULAR_OUTPUT)        || (type == LIBXSMM_DNN_GRADIENT_OUTPUT)        ||
    (type == LIBXSMM_DNN_REGULAR_INPUT_ADD)     || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD)     ||
    (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)  || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  ||
    (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) ||
    (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)     || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      ||
    (type == LIBXSMM_DNN_CHANNEL_VARIANCE)      || (type == LIBXSMM_DNN_RELU_MASK);
  if (!type_supported) {
    return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle || 0 == tensor) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  {
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fusedgroupnorm_create_tensor_datalayout(handle, type, &status);
    if (0 == libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status)) {
      /* resolve the destination slot for this tensor type */
      libxsmm_dnn_tensor** slot = 0;
      if      (type == LIBXSMM_DNN_REGULAR_INPUT)          slot = &handle->reg_input;
      else if (type == LIBXSMM_DNN_GRADIENT_INPUT)         slot = &handle->grad_input;
      else if (type == LIBXSMM_DNN_REGULAR_OUTPUT)         slot = &handle->reg_output;
      else if (type == LIBXSMM_DNN_GRADIENT_OUTPUT)        slot = &handle->grad_output;
      else if (type == LIBXSMM_DNN_REGULAR_INPUT_ADD)      slot = &handle->reg_add;
      else if (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD)     slot = &handle->grad_add;
      else if (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)   slot = &handle->reg_beta;
      else if (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  slot = &handle->grad_beta;
      else if (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA)  slot = &handle->reg_gamma;
      else if (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) slot = &handle->grad_gamma;
      else if (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)      slot = &handle->expvalue;
      else if (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      slot = &handle->rcpstddev;
      else if (type == LIBXSMM_DNN_CHANNEL_VARIANCE)       slot = &handle->variance;
      else if (type == LIBXSMM_DNN_RELU_MASK)              slot = &handle->relumask;
      if (0 != slot) { /* always true: type was validated above */
        *slot = (libxsmm_dnn_tensor*)tensor;
      }
    } else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
  }
  return status;
}
/** Returns the tensor currently bound to the given slot of the handle.
 *  @param handle group-norm handle
 *  @param type   tensor type slot to read
 *  @param status out: LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
 *                or LIBXSMM_DNN_ERR_INVALID_HANDLE
 *  @return bound tensor, or NULL when unbound or on error. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fusedgroupnorm_get_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor* result = 0;
  /* positive form of the supported-type check */
  const int type_supported =
    (type == LIBXSMM_DNN_REGULAR_INPUT)         || (type == LIBXSMM_DNN_GRADIENT_INPUT)         ||
    (type == LIBXSMM_DNN_REGULAR_OUTPUT)        || (type == LIBXSMM_DNN_GRADIENT_OUTPUT)        ||
    (type == LIBXSMM_DNN_REGULAR_INPUT_ADD)     || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD)     ||
    (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)  || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  ||
    (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) ||
    (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)     || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      ||
    (type == LIBXSMM_DNN_CHANNEL_VARIANCE)      || (type == LIBXSMM_DNN_RELU_MASK);
  *status = LIBXSMM_DNN_SUCCESS;
  if (!type_supported) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  } else if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* resolve the slot holding the requested tensor */
    if      (type == LIBXSMM_DNN_REGULAR_INPUT)          result = handle->reg_input;
    else if (type == LIBXSMM_DNN_GRADIENT_INPUT)         result = handle->grad_input;
    else if (type == LIBXSMM_DNN_REGULAR_OUTPUT)         result = handle->reg_output;
    else if (type == LIBXSMM_DNN_GRADIENT_OUTPUT)        result = handle->grad_output;
    else if (type == LIBXSMM_DNN_REGULAR_INPUT_ADD)      result = handle->reg_add;
    else if (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD)     result = handle->grad_add;
    else if (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)   result = handle->reg_beta;
    else if (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  result = handle->grad_beta;
    else if (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA)  result = handle->reg_gamma;
    else if (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) result = handle->grad_gamma;
    else if (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)      result = handle->expvalue;
    else if (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      result = handle->rcpstddev;
    else if (type == LIBXSMM_DNN_CHANNEL_VARIANCE)       result = handle->variance;
    else if (type == LIBXSMM_DNN_RELU_MASK)              result = handle->relumask;
    /* no trailing else: type was validated above */
  }
  return result;
}
/** Unbinds the tensor in the given slot of the handle (does not destroy the
 *  tensor; the caller keeps ownership).
 *  @param handle group-norm handle
 *  @param type   tensor type slot to clear
 *  @return LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE, or
 *          LIBXSMM_DNN_ERR_INVALID_HANDLE. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_release_tensor(libxsmm_dnn_fusedgroupnorm* handle, const libxsmm_dnn_tensor_type type) {
  /* positive form of the supported-type check */
  const int type_supported =
    (type == LIBXSMM_DNN_REGULAR_INPUT)         || (type == LIBXSMM_DNN_GRADIENT_INPUT)         ||
    (type == LIBXSMM_DNN_REGULAR_OUTPUT)        || (type == LIBXSMM_DNN_GRADIENT_OUTPUT)        ||
    (type == LIBXSMM_DNN_REGULAR_INPUT_ADD)     || (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD)     ||
    (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)  || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  ||
    (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) ||
    (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)     || (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      ||
    (type == LIBXSMM_DNN_CHANNEL_VARIANCE)      || (type == LIBXSMM_DNN_RELU_MASK);
  if (!type_supported) {
    return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* clear the slot corresponding to the requested type */
  if      (type == LIBXSMM_DNN_REGULAR_INPUT)          handle->reg_input = 0;
  else if (type == LIBXSMM_DNN_GRADIENT_INPUT)         handle->grad_input = 0;
  else if (type == LIBXSMM_DNN_REGULAR_OUTPUT)         handle->reg_output = 0;
  else if (type == LIBXSMM_DNN_GRADIENT_OUTPUT)        handle->grad_output = 0;
  else if (type == LIBXSMM_DNN_REGULAR_INPUT_ADD)      handle->reg_add = 0;
  else if (type == LIBXSMM_DNN_GRADIENT_INPUT_ADD)     handle->grad_add = 0;
  else if (type == LIBXSMM_DNN_REGULAR_CHANNEL_BETA)   handle->reg_beta = 0;
  else if (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BETA)  handle->grad_beta = 0;
  else if (type == LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA)  handle->reg_gamma = 0;
  else if (type == LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA) handle->grad_gamma = 0;
  else if (type == LIBXSMM_DNN_CHANNEL_EXPECTVAL)      handle->expvalue = 0;
  else if (type == LIBXSMM_DNN_CHANNEL_RCPSTDDEV)      handle->rcpstddev = 0;
  else if (type == LIBXSMM_DNN_CHANNEL_VARIANCE)       handle->variance = 0;
  else if (type == LIBXSMM_DNN_RELU_MASK)              handle->relumask = 0;
  /* no trailing else: type was validated above */
  return LIBXSMM_DNN_SUCCESS;
}
/** Single-threaded-entry execution of the fused group-norm for one thread of
 *  a team. Dispatches on compute kind (FWD/BWD) and buffer format; only the
 *  LIBXSMM custom format is implemented.
 *  @param handle       group-norm handle
 *  @param kind         LIBXSMM_DNN_COMPUTE_KIND_FWD or _BWD
 *  @param start_thread first logical thread id of the team
 *  @param tid          this thread's id
 *  @return status of the dispatched kernel, or an error code. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_execute_st(libxsmm_dnn_fusedgroupnorm* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else if (LIBXSMM_DNN_COMPUTE_KIND_FWD == kind) {
    if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM == handle->desc.buffer_format) {
      status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
    }
  } else if (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) {
    if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM == handle->desc.buffer_format) {
      status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_KIND;
  }
  return status;
}
/** Cross-handle statistics reduction entry point for fused group-norm.
 *  Only the backward kind with the LIBXSMM custom format is implemented.
 *  @param handles      array of num_handles group-norm handles
 *  @param num_handles  number of entries in handles (must be > 0)
 *  @param kind         only LIBXSMM_DNN_COMPUTE_KIND_BWD is supported
 *  @param start_thread first logical thread id of the team
 *  @param tid          this thread's id
 *  @return status of the dispatched reduction, or an error code. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == handles || 0 >= num_handles) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else if (LIBXSMM_DNN_COMPUTE_KIND_BWD != kind) {
    status = LIBXSMM_DNN_ERR_INVALID_KIND;
  } else if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM != handles[0]->desc.buffer_format) {
    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
  } else {
    status = libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom( handles, num_handles, start_thread, tid );
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fusedgroupnorm_backward.h"
#include "libxsmm_main.h"
#if 0
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
/* DISABLED CODE: this definition sits inside an "#if 0" region and is not compiled.
 * AVX512 backward fused group-norm kernel, FP32 input/output, channel block of 16.
 * Selects a template specialization according to the requested fusion
 * (plain GN / eltwise-add / ReLU / ReLU-with-mask and combinations).
 * NOTE(review): the fusion tests below use "(fuse_ops & MASK) > 0" while the active
 * generic path in this file uses "(fuse_ops & MASK) == MASK"; if a mask combines
 * several bits the ">0" form also matches partial fusions -- confirm before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* DISABLED CODE ("#if 0" region): AVX512 backward fused group-norm kernel,
 * FP32 input/output, channel block of 32; fusion-dependent template dispatch.
 * NOTE(review): "(fuse_ops & MASK) > 0" here vs. "== MASK" in the active generic
 * path -- verify the branch ordering/semantics before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* DISABLED CODE ("#if 0" region): AVX512 backward fused group-norm kernel,
 * FP32 input/output, channel block of 64; fusion-dependent template dispatch.
 * NOTE(review): "(fuse_ops & MASK) > 0" here vs. "== MASK" in the active generic
 * path -- verify the branch ordering/semantics before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* DISABLED CODE ("#if 0" region): AVX512 backward fused group-norm kernel,
 * BF16 input/output with FP32 statistics, channel block of 16. The
 * LIBXSMM_DNN_FUSEDGN_BWD_BF16 define switches the shared template into its
 * bf16 specialization.
 * NOTE(review): "(fuse_ops & MASK) > 0" here vs. "== MASK" in the active generic
 * path -- verify before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_BWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* DISABLED CODE ("#if 0" region): AVX512 backward fused group-norm kernel,
 * BF16 input/output with FP32 statistics, channel block of 32.
 * NOTE(review): "(fuse_ops & MASK) > 0" here vs. "== MASK" in the active generic
 * path -- verify before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_BWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* DISABLED CODE ("#if 0" region): AVX512 backward fused group-norm kernel,
 * BF16 input/output with FP32 statistics, channel block of 64.
 * NOTE(review): "(fuse_ops & MASK) > 0" here vs. "== MASK" in the active generic
 * path -- verify before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_BWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#endif
/* Single-threaded (per-thread) backward pass of the fused group-norm layer for the
 * LIBXSMM custom tensor format. Validates that all tensors required by the requested
 * fusion are bound, then instantiates the generic template for the handle's datatype
 * pair (f32/f32 or bf16/bf16) and fusion combination. Returns a libxsmm DNN status. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if all required tensors are bound */
if ( handle->reg_input == 0 || handle->reg_gamma == 0 ||
handle->grad_input == 0 || handle->grad_output == 0 ||
handle->grad_beta == 0 || handle->grad_gamma == 0 ||
handle->expvalue == 0 || handle->rcpstddev == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* group-norm itself needs scratch memory */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_GN) > 0 ) {
if ( handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* fused eltwise-add needs the gradient of the add input */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
if ( handle->grad_add == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* fused ReLU needs the forward output to recompute the activation mask */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
if ( handle->reg_output == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* fused ReLU-with-mask needs the stored ReLU mask */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
if ( handle->relumask == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* check if we are on an AVX512 platform */
/* NOTE(review): the AVX512 dispatch below is disabled by the outer "#if 0";
 * all datatypes currently go through the generic template path. */
#if 0
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 16) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c16( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 32) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c32( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 64) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_f32_f32_c64( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedgroupnorm_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
#endif
{
/* generic (portable) template path, selected by datatype pair */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
/* element typedefs consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* exact-mask tests: each branch instantiates the template with the matching
 * ELTWISE/RELU feature macros defined */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
/* bf16 in/out with fp32 statistics; BWD_BF16 switches the template's math */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_BWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_BWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_BWD_BF16
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
  /* Backward pass for the NHWC tensor layout: not implemented; the stub only
   * silences unused-parameter warnings and reports the missing feature. */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/* Reduces the backward statistics (dgamma, dbeta) across several group-norm handles:
 * pass 1 accumulates every handle's gradients into handle 0, pass 2 broadcasts the
 * reduced values back to all handles so each one holds the global sum. Work is
 * partitioned over feature-map blocks per thread; since each thread owns the same
 * fm range in both passes, no barrier is needed between them -- only the final
 * barrier that synchronizes all threads before returning. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
int l_count;
/* check if all required tensors are bound */
for ( l_count = 0; l_count < num_handles; ++l_count ) {
if ( handles[l_count]->grad_beta == 0 || handles[l_count]->grad_gamma == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
#if 0
/* check if we are on an AVX512 platform */
if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
status = libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom_avx512( handles, num_handles, start_thread, tid );
} else
#endif
{
/* blocking taken from handle 0; all handles are assumed to share it */
const int nBlocksFm = handles[0]->blocksifm;
const int nFmBlock = handles[0]->ifmblock;
/* computing first logical thread */
const int ltid = tid - start_thread;
/* number of tasks that could be run in parallel */
const int work2 = nBlocksFm;
/* compute chunk size */
const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2;
const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2;
int v = 0, fm;
LIBXSMM_VLA_DECL(2, float, dgamma0, (float*)handles[0]->grad_gamma->data, nFmBlock);
LIBXSMM_VLA_DECL(2, float, dbeta0, (float*)handles[0]->grad_beta->data, nFmBlock);
/* lazy barrier init */
libxsmm_barrier_init(handles[0]->barrier, ltid);
/* now we need to reduce the dgamma and dbeta */
/* pass 1: accumulate handles 1..n-1 into handle 0 (this thread's fm range only) */
for ( l_count = 1; l_count < num_handles; ++l_count ) {
LIBXSMM_VLA_DECL(2, float, dgammar, (float*)handles[l_count]->grad_gamma->data, nFmBlock);
LIBXSMM_VLA_DECL(2, float, dbetar, (float*)handles[l_count]->grad_beta->data, nFmBlock);
for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock);
float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock);
float* dgammar_ptr = &LIBXSMM_VLA_ACCESS(2, dgammar, fm, 0, nFmBlock);
float* dbetar_ptr = &LIBXSMM_VLA_ACCESS(2, dbetar, fm, 0, nFmBlock);
LIBXSMM_PRAGMA_SIMD
for ( v=0; v < nFmBlock; v++ ) {
dgamma0_ptr[v] += dgammar_ptr[v];
dbeta0_ptr[v] += dbetar_ptr[v];
}
}
}
/* pass 2: copy the reduced values from handle 0 back to handles 1..n-1 */
for ( l_count = 1; l_count < num_handles; ++l_count ) {
LIBXSMM_VLA_DECL(2, float, dgammar, (float*)handles[l_count]->grad_gamma->data, nFmBlock);
LIBXSMM_VLA_DECL(2, float, dbetar, (float*)handles[l_count]->grad_beta->data, nFmBlock);
for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
float* dgamma0_ptr = &LIBXSMM_VLA_ACCESS(2, dgamma0, fm, 0, nFmBlock);
float* dbeta0_ptr = &LIBXSMM_VLA_ACCESS(2, dbeta0, fm, 0, nFmBlock);
float* dgammar_ptr = &LIBXSMM_VLA_ACCESS(2, dgammar, fm, 0, nFmBlock);
float* dbetar_ptr = &LIBXSMM_VLA_ACCESS(2, dbetar, fm, 0, nFmBlock);
LIBXSMM_PRAGMA_SIMD
for ( v=0; v < nFmBlock; v++ ) {
dgammar_ptr[v] = dgamma0_ptr[v];
dbetar_ptr[v] = dbeta0_ptr[v];
}
}
}
/* synchronize all threads before the caller may consume the reduced gradients */
libxsmm_barrier_wait(handles[0]->barrier, ltid);
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FUSEDGROUPNORM_BACKWARD_H
#define LIBXSMM_DNN_FUSEDGROUPNORM_BACKWARD_H
#include <libxsmm_dnn_fusedgroupnorm.h>
/** Per-thread backward pass of the fused group-norm layer, LIBXSMM custom tensor format. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
/** Per-thread backward pass for the NHWC layout (currently returns NOT_IMPLEMENTED). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_bwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
/** Reduces dgamma/dbeta gradients across several handles and redistributes the result. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_reduce_stats_st_bwd_custom(libxsmm_dnn_fusedgroupnorm** handles, int num_handles, int start_thread, int tid);
#endif /* LIBXSMM_DNN_FUSEDGROUPNORM_BACKWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_fusedgroupnorm_forward.h"
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#if 0
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
/* DISABLED CODE ("#if 0" region): AVX512 forward fused group-norm kernel,
 * FP32 input/output, channel block of 16; fusion-dependent template dispatch.
 * NOTE(review): the fusion tests use "(fuse_ops & MASK) > 0"; compare with the
 * exact "== MASK" tests used on active paths before re-enabling. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the included template */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN) ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
/* AVX512 intrinsics not available at build time: report unsupported architecture */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* Forward fused group-norm, FP32 in/out, channel block of 32 (AVX512 path).
 * The shared f32/bf16 template is instantiated once per supported fusion by
 * toggling LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_* macros around the include. */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* fix: use exact sub-mask tests "(ops & MASK) == MASK" instead of "(ops & MASK) > 0".
 * Combined masks (e.g. ELTWISE_RELU) share bits with the single-op masks, so the
 * any-bit form could select the wrong template (e.g. RELU-only ops matched the
 * ELTWISE_RELU branch). This mirrors the generic dispatcher
 * libxsmm_dnn_fusedgroupnorm_st_fwd_custom(). */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* Forward fused group-norm, FP32 in/out, channel block of 64 (AVX512 path).
 * The shared f32/bf16 template is instantiated once per supported fusion by
 * toggling LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_* macros around the include. */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* fix: use exact sub-mask tests "(ops & MASK) == MASK" instead of "(ops & MASK) > 0".
 * Combined masks share bits with the single-op masks, so the any-bit form could
 * select the wrong template; this mirrors the generic dispatcher. */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* Forward fused group-norm, BF16 in/out (FP32 statistics), channel block of 16 (AVX512 path).
 * LIBXSMM_DNN_FUSEDGN_FWD_BF16 selects the bf16 flavor of the shared template. */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* fix: use exact sub-mask tests "(ops & MASK) == MASK" instead of "(ops & MASK) > 0".
 * Combined masks share bits with the single-op masks, so the any-bit form could
 * select the wrong template; this mirrors the generic dispatcher. */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* Forward fused group-norm, BF16 in/out (FP32 statistics), channel block of 32 (AVX512 path).
 * LIBXSMM_DNN_FUSEDGN_FWD_BF16 selects the bf16 flavor of the shared template. */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* fix: use exact sub-mask tests "(ops & MASK) == MASK" instead of "(ops & MASK) > 0".
 * Combined masks share bits with the single-op masks, so the any-bit form could
 * select the wrong template; this mirrors the generic dispatcher. */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* Forward fused group-norm, BF16 in/out (FP32 statistics), channel block of 64 (AVX512 path).
 * LIBXSMM_DNN_FUSEDGN_FWD_BF16 selects the bf16 flavor of the shared template. */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* fix: use exact sub-mask tests "(ops & MASK) == MASK" instead of "(ops & MASK) > 0".
 * Combined masks share bits with the single-op masks, so the any-bit form could
 * select the wrong template; this mirrors the generic dispatcher. */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#endif
/* Generic entry point for the forward fused group-norm pass in "custom" layout.
 * Validates tensor bindings, then instantiates the generic template once per
 * supported fusion (GN only, +eltwise, +relu, +relu-with-mask, and combinations)
 * for FP32 or BF16 data. Returns LIBXSMM_DNN_SUCCESS or a specific error code. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if all required tensors are bound */
if ( handle->reg_input == 0 || handle->reg_output == 0 ||
handle->reg_beta == 0 || handle->reg_gamma == 0 ||
handle->expvalue == 0 || handle->rcpstddev == 0 || handle->variance == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* group-norm itself needs scratch for the reduction buffers */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_GN) > 0 ) {
if ( handle->scratch == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* fused element-wise add needs the secondary input tensor */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) > 0 ) {
if ( handle->reg_add == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* masked ReLU needs the mask output tensor */
if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) > 0 ) {
if ( handle->relumask == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
}
/* check if we are on an AVX512 platform */
/* NOTE(review): the specialized AVX512 dispatch below is compiled out via "#if 0";
 * all requests currently fall through to the generic template path. */
#if 0
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 16) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c16( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 32) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c32( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 64) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_f32_f32_c64( handle, start_thread, tid );
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
status = libxsmm_dnn_fusedgroupnorm_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
#endif
{
/* generic (architecture-independent) template path */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_stats_type;
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
/* exact sub-mask tests pick the most specific fusion first (combined before single ops) */
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
/* BF16 input/output with FP32 statistics; the BF16 macro flips the template flavor */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef float element_stats_type;
# define LIBXSMM_DNN_FUSEDGN_FWD_BF16
if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU ) {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER;
} else {
if ( handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDGN_OPS_GN ) {
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_ELTWISE
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU) == LIBXSMM_DNN_FUSEDGN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU
} else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedgroupnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDGN_FWD_ENABLE_RELU_WITH_MASK
} else {
status = LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION;
}
}
# undef LIBXSMM_DNN_FUSEDGN_FWD_BF16
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* NHWC layout for the forward fused group-norm pass: not implemented.
 * Always reports LIBXSMM_DNN_ERR_NOT_IMPLEMENTED; arguments are ignored. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid)
{
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FUSEDGROUPNORM_FORWARD_H
#define LIBXSMM_DNN_FUSEDGROUPNORM_FORWARD_H
#include <libxsmm_dnn_fusedgroupnorm.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_custom(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedgroupnorm_st_fwd_nhwc(libxsmm_dnn_fusedgroupnorm* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_FUSEDGROUPNORM_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_optimizer_sgd.h"
#include "libxsmm_main.h"
/* Allocates and initializes an optimizer handle for F32 or BF16 filters.
 * Computes the blocked filter layout (Bk x Bc x bc x bk [x lp]) either from
 * libxsmm_dnn_get_feature_map_blocks (LIBXSMM format) or from the descriptor's
 * bc/bk (CKPACKED format). Returns NULL and sets *status on failure. */
LIBXSMM_API libxsmm_dnn_optimizer* libxsmm_dnn_create_optimizer(libxsmm_dnn_optimizer_desc optimizer_desc, libxsmm_dnn_err_t* status) {
libxsmm_dnn_optimizer* handle = 0;
/* init libxsmm */
LIBXSMM_INIT
if ( (optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_F32) || (optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) ) {
/* zero entire content; not only safer but also sets data and code pointers to NULL */
handle = (libxsmm_dnn_optimizer*)calloc(1, sizeof(libxsmm_dnn_optimizer));
if (0 != handle) {
*status = LIBXSMM_DNN_SUCCESS;
/* let's make the description persistent */
handle->desc = optimizer_desc;
if ( (handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
/* we need to compute the memory layout given the */
*status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
&(handle->bc), &(handle->bk), &(handle->fm_lp_block),
handle->desc.datatype, handle->desc.datatype );
/* compute the outer blocks */
/* NOTE(review): assumes C and K are divisible by bc/bk — presumably guaranteed
 * by libxsmm_dnn_get_feature_map_blocks; confirm for odd C/K. */
handle->Bc = handle->desc.C / handle->bc;
handle->Bk = handle->desc.K / handle->bk;
} else if ( (handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0 ) {
/* low-precision pack factor: 2 bf16 values share one fp32 slot */
if ( optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
handle->fm_lp_block = 1;
} else if ( optimizer_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
handle->fm_lp_block = 2;
} else {
/* cannot happen: outer check restricts datatype to F32/BF16 */
}
handle->bc = handle->desc.bc;
handle->bk = handle->desc.bk;
handle->Bc = handle->desc.C / handle->bc;
handle->Bk = handle->desc.K / handle->bk;
} else {
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
free( handle );
handle = 0;
return handle;
}
/* create barrier */
handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
/* calculate scratch size for local optimizer copies of one feature map block per thread */
/* NOTE(review): placeholder value — only used so get_scratch_size reports a minimal,
 * non-zero requirement (64-byte alignment slack is added there). */
handle->scratch_size = 1;
} else {
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
}
} else {
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
return handle;
}
/* Releases an optimizer handle: frees its barrier (if any) and the handle itself.
 * Returns LIBXSMM_DNN_ERR_INVALID_HANDLE when called with NULL. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_optimizer(const libxsmm_dnn_optimizer* handle) {
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* release the barrier before the owning structure */
  if (0 != handle->barrier) {
    libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
  }
  /* the cast removes constness for deallocation */
  free((libxsmm_dnn_optimizer*)handle);
  return LIBXSMM_DNN_SUCCESS;
}
/* Creates the data layout describing the optimizer's blocked filter tensor.
 * F32 filters are 4-dimensional (bk, bc, Bc, Bk); BF16 filters add the
 * low-precision pack dimension (lp, bk, bc/lp, Bc, Bk). Only filter tensor
 * types and LIBXSMM/CKPACKED formats are supported. On any failure the
 * function returns NULL and sets *status accordingly.
 * Caller owns the returned layout (libxsmm_dnn_destroy_tensor_datalayout). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_optimizer_create_tensor_datalayout(const libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      layout->format = handle->desc.filter_format;
      if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_MASTER_FILTER) ) {
        if ( ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) || ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) ) {
          if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->desc.datatype;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_size[0] = handle->bk;
              layout->dim_size[1] = handle->bc;
              layout->dim_size[2] = handle->Bc;
              layout->dim_size[3] = handle->Bk;
            } else {
              /* fix: also free whichever of dim_type/dim_size was successfully
               * allocated (free(NULL) is a no-op) — previously only the layout
               * struct was freed and the surviving array leaked */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
            layout->datatype = handle->desc.datatype;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_size[0] = handle->fm_lp_block;
              layout->dim_size[1] = handle->bk;
              layout->dim_size[2] = handle->bc/handle->fm_lp_block;
              layout->dim_size[3] = handle->Bc;
              layout->dim_size[4] = handle->Bk;
            } else {
              /* fix: free partially allocated dimension arrays (see F32 branch) */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
/* Reports the scratch requirement of the optimizer handle.
 * Includes 64 extra bytes so callers that ignore alignment still fit. */
LIBXSMM_API size_t libxsmm_dnn_optimizer_get_scratch_size(const libxsmm_dnn_optimizer* handle, libxsmm_dnn_err_t* status) {
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    /* 64 byte extra in case the user code does not care about alignment */
    result = handle->scratch_size + 64;
  }
  return result;
}
/* Returns the currently bound (aligned) scratch pointer, or NULL with
 * *status set to LIBXSMM_DNN_ERR_INVALID_HANDLE when handle is NULL. */
LIBXSMM_API void* libxsmm_dnn_optimizer_get_scratch_ptr(const libxsmm_dnn_optimizer* handle, libxsmm_dnn_err_t* status)
{
  void* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    result = handle->scratch;
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return result;
}
/* Binds user-provided scratch memory to the optimizer handle.
 * The stored pointer is rounded up to the next 64-byte boundary, which is why
 * get_scratch_size reports 64 extra bytes. The scratch pointer is checked
 * before the handle, matching the original error precedence. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_bind_scratch(libxsmm_dnn_optimizer* handle, const void* scratch) {
  const uintptr_t address = (uintptr_t)scratch;
  if (0 == scratch) {
    return LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  }
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* round up to the next multiple of 64 (no-op when already aligned) */
  handle->scratch = (void*)((address + 63) & ~(uintptr_t)63);
  return LIBXSMM_DNN_SUCCESS;
}
/* Unbinds the scratch pointer from the handle; the memory itself is owned
 * (and freed) by the caller. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_release_scratch(libxsmm_dnn_optimizer* handle) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    handle->scratch = 0;
  }
  return result;
}
/* Binds a filter tensor (regular, gradient, or master) to the optimizer after
 * validating its layout against the handle's expected blocked layout.
 * The handle does not take ownership of the tensor. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_bind_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* only the three filter tensor kinds are supported */
  if ( (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && (type != LIBXSMM_DNN_MASTER_FILTER) ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  else if (handle == 0 || tensor == 0) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  else {
    libxsmm_dnn_tensor_datalayout* const expected = libxsmm_dnn_optimizer_create_tensor_datalayout(handle, type, &status);
    /* a comparison result of zero signals a matching layout */
    if ( 0 == libxsmm_dnn_compare_tensor_datalayout(expected, tensor->layout, &status) ) {
      switch (type) {
        case LIBXSMM_DNN_REGULAR_FILTER:  handle->reg_filter    = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_GRADIENT_FILTER: handle->grad_filter   = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_MASTER_FILTER:   handle->master_filter = (libxsmm_dnn_tensor*)tensor; break;
        default: /* cannot happen: type was validated above */ break;
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout( expected );
  }
  return status;
}
/* Returns the tensor currently bound for the requested filter kind
 * (regular, gradient, or master); NULL plus an error status otherwise. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_optimizer_get_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  /* only the three filter tensor kinds are supported */
  if ( (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && (type != LIBXSMM_DNN_MASTER_FILTER) ) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  else if (handle == 0) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_FILTER:  result = handle->reg_filter;    break;
      case LIBXSMM_DNN_GRADIENT_FILTER: result = handle->grad_filter;   break;
      case LIBXSMM_DNN_MASTER_FILTER:   result = handle->master_filter; break;
      default: /* cannot happen: type was validated above */ break;
    }
  }
  return result;
}
/* Unbinds the tensor of the requested filter kind from the handle; the tensor
 * object itself remains owned by the caller. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_release_tensor(libxsmm_dnn_optimizer* handle, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* only the three filter tensor kinds are supported */
  if ( (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) && (type != LIBXSMM_DNN_MASTER_FILTER) ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  else if (handle == 0) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_FILTER:  handle->reg_filter    = 0; break;
      case LIBXSMM_DNN_GRADIENT_FILTER: handle->grad_filter   = 0; break;
      case LIBXSMM_DNN_MASTER_FILTER:   handle->master_filter = 0; break;
      default: /* cannot happen: type was validated above */ break;
    }
  }
  return status;
}
/* Executes the optimizer update for one thread's slice of the filter.
 * Dispatches on the configured optimizer type; currently only SGD exists. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_optimizer_execute_st(libxsmm_dnn_optimizer* handle, /*unsigned*/int start_thread, /*unsigned*/int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    if (handle->desc.opt_type == LIBXSMM_DNN_OPTIMIZER_SGD) {
      /* fix: propagate the kernel's status instead of discarding it — previously
       * failures inside libxsmm_dnn_optimizer_sgd_st (e.g. unbound tensors or an
       * unsupported architecture) were silently reported as success */
      status = libxsmm_dnn_optimizer_sgd_st( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_optimizer_sgd.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_f32_f32(libxsmm_dnn_optimizer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_bf16_bf16(libxsmm_dnn_optimizer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_f32_f32(libxsmm_dnn_optimizer* handle, int start_thread, int tid)
{
  /* AVX512 SGD weight update for FP32 filters/gradients; the loop body is
   * textually included from the shared template, with the AVX512 code path
   * selected via LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512 and the element type
   * fixed by the typedef below. start_thread/tid presumably select this
   * thread's chunk of the filter -- confirm against the template. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_filter_type;
# define LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
# undef LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512
#else /* should not happen: the dispatcher only calls this on AVX512 targets */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_bf16_bf16(libxsmm_dnn_optimizer* handle, int start_thread, int tid)
{
  /* AVX512 SGD weight update for BF16 filters/gradients; the loop body is
   * textually included from the shared template (BF16 path selected via
   * LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512). element_master_type is float:
   * the template presumably accumulates into an FP32 master copy of the
   * weights -- confirm against the template. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_filter_type;
  typedef float element_master_type;
# define LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
# undef LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512
#else /* should not happen: the dispatcher only calls this on AVX512 targets */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st(libxsmm_dnn_optimizer* handle, int start_thread, int tid)
{
  /* Dispatch the SGD weight update for this thread: validate bound tensors,
   * then select the AVX512 specialization (per datatype) or fall back to the
   * generic template-included loop. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* guard against a NULL handle before dereferencing it, consistent with
   * libxsmm_dnn_optimizer_execute_st which reports INVALID_HANDLE */
  if ( handle == 0 ) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return status;
  }
  /* check if we have filter, grad_filter */
  if ( handle->reg_filter == 0 || handle->grad_filter == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* BF16 additionally requires the FP32 master copy of the weights */
  if ( (handle->master_filter == 0) && (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_optimizer_sgd_st_f32_f32( handle, start_thread, tid);
    } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_optimizer_sgd_st_bf16_bf16( handle, start_thread, tid);
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic (non-AVX512) fallback: loop body is textually included from the
     * shared template with the element type fixed by the typedefs below */
    if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_filter_type;
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
    } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
      typedef libxsmm_bfloat16 element_filter_type;
      typedef float element_master_type;
# define LIBXSMM_DNN_OPTIMIZER_SGD_BF16
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
# undef LIBXSMM_DNN_OPTIMIZER_SGD_BF16
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_OPTIMIZER_SGD_H
#define LIBXSMM_DNN_OPTIMIZER_SGD_H
#include <libxsmm_dnn_optimizer.h>
/** Thread-sliced SGD weight update for the given optimizer handle; returns
 *  LIBXSMM_DNN_SUCCESS, or an error status (e.g. when required tensors such
 *  as filter/grad_filter are not bound, or the datatype is unsupported). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st(libxsmm_dnn_optimizer* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_OPTIMIZER_SGD_H */
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment