"vscode:/vscode.git/clone" did not exist on "52a4480d70592dde520240b1694184612108ca6f"
Commit c454d419 authored by lisj's avatar lisj
Browse files

删除子模块的gitignore

parent 3359c1f1
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_backward.h"
#include "libxsmm_dnn_pooling_forward.h"
#include "libxsmm_main.h"
LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_pooling* handle = 0;
  int lpb;

  /* init libxsmm */
  LIBXSMM_INIT

  /* only F32->F32 and BF16->BF16 datatype combinations are supported */
  if ( ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
       ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    handle = (libxsmm_dnn_pooling*)calloc(1, sizeof(libxsmm_dnn_pooling));

    if (0 != handle) {
      *status = LIBXSMM_DNN_SUCCESS;
      /* let's make the description persistent */
      handle->desc = pooling_desc;
      /* determine the channel blocking factors for the given datatypes */
      *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C,
                                                    &(handle->ifmblock), &(handle->ofmblock), &lpb,
                                                    handle->desc.datatype_in, handle->desc.datatype_out );
      /* BUGFIX: the original stored a failing status here but kept going and
       * divided by ifmblock/ofmblock below (potentially zero); bail out cleanly */
      if ( LIBXSMM_DNN_SUCCESS != *status ) {
        free( handle );
        return 0;
      }
      /* compute the outer blocks */
      handle->blocksifm = handle->desc.C / handle->ifmblock;
      handle->blocksofm = handle->desc.C / handle->ofmblock;
      /* setting ofh and ofw: standard output-extent formula for an RxS window,
       * strides u/v and logical padding pad_h/pad_w */
      handle->ofh = (handle->desc.H + 2 * handle->desc.pad_h - handle->desc.R) / handle->desc.u + 1;
      handle->ofw = (handle->desc.W + 2 * handle->desc.pad_w - handle->desc.S) / handle->desc.v + 1;
      /* create barrier used by the threaded kernels */
      handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
      /* calculate scratch size for local pooling copies of one feature map block per thread */
      handle->scratch_size = (sizeof(float) * ( (size_t)handle->desc.H + (size_t)LIBXSMM_MAX(handle->desc.pad_h_in, handle->desc.pad_h_out)*2 )
                                            * ( (size_t)handle->desc.W + (size_t)LIBXSMM_MAX(handle->desc.pad_w_in, handle->desc.pad_w_out)*2 )
                                            * LIBXSMM_MAX( handle->ofmblock, handle->ifmblock )
                                            * handle->desc.threads );
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  return handle;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* release the thread barrier if one was created */
    if (0 != handle->barrier) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    /* the handle was obtained via calloc; the cast drops constness for free() */
    free((libxsmm_dnn_pooling*)handle);
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;

  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));

    if (layout != 0) {
      layout->format = handle->desc.buffer_format;

      if ( (type == LIBXSMM_DNN_REGULAR_INPUT)  || (type == LIBXSMM_DNN_GRADIENT_INPUT)  || (type == LIBXSMM_DNN_INPUT)  ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ||
           (type == LIBXSMM_DNN_POOLING_MASK) ) {
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            /* the pooling mask has its own datatype, everything else is F32 here */
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              /* blocked custom format: dims listed fastest-varying first */
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                /* input tensors carry the physical input padding */
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                /* output tensors carry the physical output padding */
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
                /* the mask uses the unpadded output extents */
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofw;
                layout->dim_size[2] = handle->ofh;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else { /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* BUGFIX: also release whichever of the two arrays did get allocated
               * (free(NULL) is a no-op), instead of leaking it */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            /* identical to the F32 branch above, but with BF16 element datatype */
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofw;
                layout->dim_size[2] = handle->ofh;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* BUGFIX: also release partially allocated dim arrays (see F32 branch) */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = handle->desc.datatype_in;
            }
            /* BUGFIX: a stray unconditional assignment used to overwrite the mask
             * datatype selected above with datatype_in; it has been removed */
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              /* NHWC format: C fastest, then W, H, N */
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->desc.N;
              } else {
                /* no NHWC mask layout is defined; also reached for the mask type */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* BUGFIX: allocation failure was silently ignored here (a TODO in the
               * original), returning SUCCESS with an empty layout; report it instead */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status) {
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* report 64 bytes of slack so bind_scratch can align an arbitrary pointer */
    result = handle->scratch_size + 64;
  }
  return result;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling* handle, const void* scratch) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == scratch) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  } else if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* round the user-provided pointer up to the next 64-byte boundary */
    const uintptr_t address = (uintptr_t)scratch;
    const uintptr_t remainder = address % 64;
    handle->scratch = (void*)(0 == remainder ? address : address + (64 - remainder));
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* drop the reference only; the scratch memory itself is owned by the caller */
    handle->scratch = 0;
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* only the five pooling-related tensor types may be bound */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)  && (type != LIBXSMM_DNN_GRADIENT_INPUT)  &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_POOLING_MASK) ) {
    return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }

  if (0 == handle || 0 == tensor) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  } else {
    /* the tensor's layout must match the layout the handle expects for this type */
    libxsmm_dnn_tensor_datalayout* const expected = libxsmm_dnn_pooling_create_tensor_datalayout(handle, type, &status);
    if (0 == libxsmm_dnn_compare_tensor_datalayout(expected, tensor->layout, &status)) {
      libxsmm_dnn_tensor* const bound = (libxsmm_dnn_tensor*)tensor; /* drop constness for storage */
      switch (type) {
        case LIBXSMM_DNN_REGULAR_INPUT:   handle->reg_input   = bound; break;
        case LIBXSMM_DNN_GRADIENT_INPUT:  handle->grad_input  = bound; break;
        case LIBXSMM_DNN_REGULAR_OUTPUT:  handle->reg_output  = bound; break;
        case LIBXSMM_DNN_GRADIENT_OUTPUT: handle->grad_output = bound; break;
        default: /* LIBXSMM_DNN_POOLING_MASK; other types were rejected above */
          handle->mask = bound; break;
      }
    } else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout(expected);
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  /* only the five pooling-related tensor types can be queried */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)  && (type != LIBXSMM_DNN_GRADIENT_INPUT)  &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_POOLING_MASK) ) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return result;
  }

  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:   result = handle->reg_input;   break;
      case LIBXSMM_DNN_GRADIENT_INPUT:  result = handle->grad_input;  break;
      case LIBXSMM_DNN_REGULAR_OUTPUT:  result = handle->reg_output;  break;
      case LIBXSMM_DNN_GRADIENT_OUTPUT: result = handle->grad_output; break;
      default: /* LIBXSMM_DNN_POOLING_MASK; other types were rejected above */
        result = handle->mask; break;
    }
  }
  return result;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* only the five pooling-related tensor types can be released */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)  && (type != LIBXSMM_DNN_GRADIENT_INPUT)  &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_POOLING_MASK) ) {
    return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }

  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* clear the reference; the tensor itself remains owned by the caller */
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:   handle->reg_input   = 0; break;
      case LIBXSMM_DNN_GRADIENT_INPUT:  handle->grad_input  = 0; break;
      case LIBXSMM_DNN_REGULAR_OUTPUT:  handle->reg_output  = 0; break;
      case LIBXSMM_DNN_GRADIENT_OUTPUT: handle->grad_output = 0; break;
      default: /* LIBXSMM_DNN_POOLING_MASK; other types were rejected above */
        handle->mask = 0; break;
    }
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else if (LIBXSMM_DNN_COMPUTE_KIND_FWD == kind) {
    /* only the LIBXSMM custom (blocked) buffer format has a forward kernel */
    if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM == handle->desc.buffer_format) {
      status = libxsmm_dnn_pooling_st_fwd_custom(handle, start_thread, tid);
    } else {
      /* NOTE(review): error code looks inherited from the fusedbn path — kept for compatibility */
      status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
    }
  } else if (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) {
    if (LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM == handle->desc.buffer_format) {
      status = libxsmm_dnn_pooling_st_bwd_custom(handle, start_thread, tid);
    } else {
      status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_KIND;
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_backward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Backward pooling, custom (blocked) format, F32 in/out, channel block of 16,
 * AVX-512 code path. The kernel body is textually instantiated from a template
 * file via #include; the element types and the MAX vs. AVG variant are selected
 * through local typedefs and the LIBXSMM_DNN_POOLING_BWD_* macros. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max-pooling backward additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Backward pooling, custom (blocked) format, F32 in/out, channel block of 32,
 * AVX-512 code path; the kernel body is instantiated from a template file. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max-pooling backward additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Backward pooling, custom (blocked) format, F32 in/out, channel block of 64,
 * AVX-512 code path; the kernel body is instantiated from a template file. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max-pooling backward additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Backward pooling, custom (blocked) format, BF16 in/out, channel block of 16,
 * AVX-512 code path; shares the F32 template, with the BF16 variant selected
 * via the LIBXSMM_DNN_POOLING_BWD_BF16 macro around the instantiation. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max-pooling backward additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_BWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Backward pooling, custom (blocked) format, BF16 in/out, channel block of 32,
 * AVX-512 code path; BF16 variant of the shared template instantiation. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max-pooling backward additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_BWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Backward pooling, custom (blocked) format, BF16 in/out, channel block of 64,
 * AVX-512 code path; BF16 variant of the shared template instantiation. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max-pooling backward additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_BWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Dispatcher for backward pooling in the custom (blocked) format: validates the
 * bound tensors, then selects an AVX-512 specialization by channel block size
 * (16/32/64) and datatype, falling back to a generic template otherwise. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and mask */
/* the mask is only required for max pooling (argmax routing of gradients) */
if ( handle->grad_input == 0 || handle->grad_output == 0 ||
( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 16) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16( handle, start_thread, tid);
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 32) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32( handle, start_thread, tid);
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 64) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64( handle, start_thread, tid);
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* generic fallback: portable template instantiation, no AVX-512 intrinsics */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_BWD_BF16
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  /* backward pooling for the NHWC format is not implemented; stub that
   * reports this to the caller without touching any of its arguments */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_POOLING_BACKWARD_H
#define LIBXSMM_DNN_POOLING_BACKWARD_H
#include <libxsmm_dnn_pooling.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_POOLING_BACKWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_forward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
/* Forward pooling, custom (blocked) format, F32 in/out, channel block of 16,
 * AVX-512 code path; the kernel body is instantiated from a template file
 * with the MAX vs. AVG variant selected via the LIBXSMM_DNN_POOLING_FWD_*
 * macros. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
/* max pooling additionally needs an integer mask element type */
# define LIBXSMM_DNN_POOLING_FWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 specialized forward pooling for FP32 input/output with a feature-map block of 32.
 * Same dispatch scheme as the c16 variant; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
/* integer mask holds argmax positions for the backward pass */
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 specialized forward pooling for FP32 input/output with a feature-map block of 64.
 * Same dispatch scheme as the c16 variant; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
/* integer mask holds argmax positions for the backward pass */
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 specialized forward pooling for BF16 input/output with a feature-map block of 16.
 * Shares the FP32 template; LIBXSMM_DNN_POOLING_FWD_BF16 selects the bf16 path inside it. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
/* enable the bf16-specific load/store path in the shared template */
# define LIBXSMM_DNN_POOLING_FWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 specialized forward pooling for BF16 input/output with a feature-map block of 32.
 * Shares the FP32 template; LIBXSMM_DNN_POOLING_FWD_BF16 selects the bf16 path inside it. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
/* enable the bf16-specific load/store path in the shared template */
# define LIBXSMM_DNN_POOLING_FWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* AVX-512 specialized forward pooling for BF16 input/output with a feature-map block of 64.
 * Shares the FP32 template; LIBXSMM_DNN_POOLING_FWD_BF16 selects the bf16 path inside it. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
/* enable the bf16-specific load/store path in the shared template */
# define LIBXSMM_DNN_POOLING_FWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Thread-parallel forward pooling dispatcher for the custom (blocked) tensor format.
 * Validates bound tensors, then dispatches on architecture, feature-map block size
 * (16/32/64) and datatype (F32/BF16) to a specialized AVX-512 kernel; any other
 * combination falls back to the generic template at the bottom.
 * Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND if input/output (or, for max pooling,
 * the mask) is missing, and LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE for datatypes
 * other than pure-F32 or pure-BF16. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and mask */
if ( handle->reg_input == 0 || handle->reg_output == 0 ||
( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
/* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 16) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
/* NOTE(review): this assert demands a bound mask unconditionally, but the entry
 * check above only requires the mask for MAX pooling; an AVG call without a mask
 * would trip it in a debug build -- confirm intent. (same for the asserts below) */
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16( handle, start_thread, tid);
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 32) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32( handle, start_thread, tid);
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
(handle->ofmblock == 64) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64( handle, start_thread, tid);
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
LIBXSMM_ASSERT(NULL != handle->mask);
status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
/* generic fallback: non-AVX512 targets or feature-map blocks other than 16/32/64 */
{
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
} else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_FWD_BF16
if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
} else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
}
# undef LIBXSMM_DNN_POOLING_FWD_BF16
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* Forward pooling for the NHWC tensor format: not implemented.
 * All arguments are ignored; always returns LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  /* silence unused-parameter warnings for the stub */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_POOLING_FORWARD_H
#define LIBXSMM_DNN_POOLING_FORWARD_H
#include <libxsmm_dnn_pooling.h>
/* Thread-parallel forward pooling for the custom (blocked) tensor format. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid);
/* Thread-parallel forward pooling for the NHWC format (currently returns NOT_IMPLEMENTED). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_POOLING_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas, Kunal Banerjee (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_rnncell_forward.h"
#include "libxsmm_dnn_rnncell_backward_weight_update.h"
#include "libxsmm_dnn_elementwise.h"
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/* Allocates and initializes an RNN-cell handle from the descriptor.
 * Validates datatypes (pure F32 or pure BF16; BF16 additionally requires an
 * AVX512-CORE target and C/K divisible by 16) and max_T >= 1, derives blocking
 * factors, pre-dispatches batch-reduce GEMM kernels for the BF16 path, and
 * creates the thread barrier. On failure *status is set and NULL is returned;
 * on success the caller owns the handle (free via libxsmm_dnn_destroy_rnncell). */
LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status)
{
libxsmm_dnn_rnncell* handle = 0;
/* init libxsmm */
LIBXSMM_INIT
/* some check we can do before allocating the handle */
if ( (rnncell_desc.datatype_in != rnncell_desc.datatype_out) ||
( (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) && (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) ) ) {
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return NULL;
}
/* let's do some simple checks for BF16 as this limits the cell and architecture */
if ( (rnncell_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (rnncell_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
if ( (LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid) || (rnncell_desc.C % 16 != 0) || (rnncell_desc.K % 16 != 0) ) {
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return NULL;
}
}
/* we need at least one timestep */
if (rnncell_desc.max_T < 1) {
*status = LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL;
return NULL;
}
/* zero entire content; not only safer but also sets data and code pointers to NULL */
handle = (libxsmm_dnn_rnncell*)calloc(1, sizeof(libxsmm_dnn_rnncell));
if (NULL != handle) {
*status = LIBXSMM_DNN_SUCCESS;
/* initialize known handle components */
handle->desc = rnncell_desc;
/* set current seq length to max length */
handle->T = rnncell_desc.max_T;
/* set blocking factors; 0 in the descriptor means "use the default of 64" */
handle->bk = (handle->desc.bk == 0) ? 64 : handle->desc.bk;
handle->bn = (handle->desc.bn == 0) ? 64 : handle->desc.bn;
handle->bc = (handle->desc.bc == 0) ? 64 : handle->desc.bc;
handle->use_fwd_fused_impl = handle->desc.use_fwd_fused_impl;
handle->fwd_block = handle->desc.fwd_block;
handle->bwdupd_block = handle->desc.bwdupd_block;
/* low-precision block: 2 elements packed per VNNI lane for BF16, 1 otherwise */
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
handle->lpb = 2;
} else {
handle->lpb = 1;
}
/* validate blocking factors: fall back to the full dimension when it does not divide evenly */
if ( handle->desc.N % handle->bn != 0 ) {
handle->bn = handle->desc.N;
*status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING;
}
if ( handle->desc.C % handle->bc != 0 ) {
handle->bc = handle->desc.C;
*status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING;
}
if ( handle->desc.K % handle->bk != 0 ) {
handle->bk = handle->desc.K;
*status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING;
}
/* If in SPR, generate tilerelease kernel */
if ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bk, handle->bk, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL);
}
/* In case of BF16 for now hoist the BRGEMM and make them to use STRIDED variant by default */
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
libxsmm_blasint BF, CB_BLOCKS, KB_BLOCKS;
const libxsmm_blasint K = handle->desc.K;
const libxsmm_blasint N = handle->desc.N;
const libxsmm_blasint C = handle->desc.C;
const libxsmm_blasint bk = handle->bk;
const libxsmm_blasint bn = handle->bn;
const libxsmm_blasint bc = handle->bc;
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
const libxsmm_blasint nBlocks = N/bn;
int tc_flags = 0;
int kernel_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
int stride_a, stride_b;
/* NOTE(review): this SPR check uses '==' while the tilerelease check above uses
 * '>=' -- the second '<= ALLFEAT' clause is redundant with '=='; confirm which
 * comparison is intended. */
if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
kernel_flags = ((handle->bk % 32 == 0) && (handle->bc % 32 == 0) && (handle->bn % 32 == 0)) ? LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG : 0;
kernel_flags = kernel_flags | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
}
/* Blocking reduction domain if it is too large */
BF = 1;
if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
BF = 8;
while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
BF--;
}
}
if (C > 2048 || K > 2048) {
BF = 16;
while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
BF--;
}
}
if (C == 2048 && K == 1024) {
BF = 2;
}
/* NOTE(review): the heuristic BF computed above is unconditionally overwritten
 * here by the descriptor value; the preceding block is dead code. Also, if
 * desc.fwd_block (or desc.bwdupd_block below) is 0, the divisions that follow
 * divide by zero -- confirm callers always supply positive block counts. */
BF = handle->fwd_block;
if (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) {
/* NCPACKED layout: panels are contiguous, so leading dimensions are the block sizes */
CB_BLOCKS = cBlocks/BF;
KB_BLOCKS = kBlocks/BF;
/* define batch-reduce gemm kernels */
stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bc * bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &bc, &bk, NULL, NULL, &kernel_flags, NULL );
stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bk * bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &bk, &bk, NULL, NULL, &kernel_flags, NULL );
if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
handle->fwd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL );
}
/* backward/weight-update kernels use their own reduction blocking */
BF = handle->bwdupd_block;
KB_BLOCKS = kBlocks/BF;
stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bk * bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &bk, &bc, NULL, NULL, &kernel_flags, NULL);
stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &bn, &bk, NULL, NULL, &kernel_flags, NULL);
stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bn * bc * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &bn, &bk, NULL, NULL, &kernel_flags, NULL);
stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &bk, &bk, NULL, NULL, &kernel_flags, NULL);
if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
handle->bwdupd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);
}
} else {
/* non-packed layout: leading dimensions are the full C/K/N extents */
CB_BLOCKS = cBlocks/BF;
KB_BLOCKS = kBlocks/BF;
/* define batch-reduce gemm kernels */
stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bc * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL );
stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL );
if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
handle->fwd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL );
}
BF = handle->bwdupd_block;
KB_BLOCKS = kBlocks/BF;
stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);
stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);
if ((libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
handle->bwdupd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);
}
}
}
/* Need to allocate space for scratch libxsmm_dnn_tensor's, let's set all pointers to zero */
handle->internal_z = 0;
handle->scratch_wT = 0;
handle->scratch_rT = 0;
handle->scratch_xT = 0;
handle->scratch_hT = 0;
handle->scratch_deltat = 0;
handle->scratch_di = 0;
handle->scratch_df = 0;
handle->scratch_do = 0;
handle->scratch_dci = 0;
handle->scratch_diB = 0;
handle->scratch_dfB = 0;
handle->scratch_dpB = 0;
handle->scratch_dciB = 0;
/* initialize a high-performant barrier */
handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
if (NULL == handle->barrier)
{
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
free(handle);
return NULL;
}
} else {
*status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
}
return handle;
}
/* Destroys an RNN-cell handle created by libxsmm_dnn_create_rnncell:
 * releases the thread barrier (if any) and frees the handle itself.
 * Returns LIBXSMM_DNN_ERR_INVALID_HANDLE when called with NULL. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle)
{
  if (NULL == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* release the barrier before freeing the structure that references it */
  if (NULL != handle->barrier) {
    libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
  }
  /* cast drops the const qualifier: the handle was heap-allocated by the create call */
  free((libxsmm_dnn_rnncell*)handle);
  return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
libxsmm_dnn_tensor_datalayout* layout;
*status = LIBXSMM_DNN_SUCCESS;
layout = 0;
if (handle != 0) {
/* zero entire content; not only safer but also sets data and code pointers to NULL */
layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
if (layout != 0) {
if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ||
(type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) ||
(type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
(type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) ||
(type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
layout->format = handle->desc.buffer_format;
layout->tensor_type = LIBXSMM_DNN_ACTIVATION;
if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
layout->datatype = handle->desc.datatype_in;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
layout->dim_size[0] = (unsigned int)handle->bc;
layout->dim_size[1] = (unsigned int)handle->bn;
layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
layout->dim_size[4] = (unsigned int)handle->desc.max_T;
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) ||
(type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
(type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) ||
(type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bn;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
layout->dim_size[4] = (unsigned int)handle->desc.max_T;
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) {
if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
layout->datatype = handle->desc.datatype_in;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 3;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
layout->dim_size[0] = (unsigned int)handle->desc.C;
layout->dim_size[1] = (unsigned int)handle->desc.N;
layout->dim_size[2] = (unsigned int)handle->desc.max_T;
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) ||
(type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
(type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) ||
(type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) ||
(type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
layout->dim_size[0] = (unsigned int)handle->desc.K;
layout->dim_size[1] = (unsigned int)handle->desc.N;
layout->dim_size[2] = (unsigned int)handle->desc.max_T;
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ||
(type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
layout->format = handle->desc.filter_format;
layout->tensor_type = LIBXSMM_DNN_FILTER;
if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = handle->desc.datatype_in;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bc;
layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[4] = 4;
} else {
layout->dim_size[4] = 3;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[4] = 4;
} else {
layout->dim_size[4] = 3;
}
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bc;
layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
layout->datatype = handle->desc.datatype_in;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 6;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[5] = 4;
} else {
layout->dim_size[5] = 3;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[5] = 4;
} else {
layout->dim_size[5] = 3;
}
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) {
if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
layout->datatype = handle->desc.datatype_in;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 2;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
layout->dim_size[1] = (unsigned int)handle->desc.C;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
layout->dim_size[1] = (unsigned int)handle->desc.C;
} else {
layout->dim_size[0] = (unsigned int)handle->desc.K;
layout->dim_size[1] = (unsigned int)handle->desc.C;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
layout->dim_size[1] = (unsigned int)handle->desc.K;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
layout->dim_size[1] = (unsigned int)handle->desc.K;
} else {
layout->dim_size[0] = (unsigned int)handle->desc.K;
layout->dim_size[1] = (unsigned int)handle->desc.K;
}
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) || (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
layout->format = handle->desc.filter_format;
layout->tensor_type = LIBXSMM_DNN_FILTER;
if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = handle->desc.datatype_in;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->bc;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[4] = 4;
} else {
layout->dim_size[4] = 3;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[4] = 4;
} else {
layout->dim_size[4] = 3;
}
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = (unsigned int)handle->bc;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = (unsigned int)handle->bk;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
layout->datatype = handle->desc.datatype_in;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 6;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bc;
layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[5] = 4;
} else {
layout->dim_size[5] = 3;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[5] = 4;
} else {
layout->dim_size[5] = 3;
}
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 5;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bc;
layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc);
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_size[0] = (unsigned int)handle->lpb;
layout->dim_size[1] = (unsigned int)handle->bk;
layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) {
if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
layout->datatype = handle->desc.datatype_in;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 2;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[0] = (unsigned int)handle->desc.C;
layout->dim_size[1] = (unsigned int)(handle->desc.K * 4);
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_size[0] = (unsigned int)handle->desc.C;
layout->dim_size[1] = (unsigned int)(handle->desc.K * 3);
} else {
layout->dim_size[0] = (unsigned int)handle->desc.C;
layout->dim_size[1] = (unsigned int)handle->desc.K;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[0] = (unsigned int)handle->desc.K;
layout->dim_size[1] = (unsigned int)(handle->desc.K * 4);
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_size[0] = (unsigned int)handle->desc.K;
layout->dim_size[1] = (unsigned int)(handle->desc.K * 3);
} else {
layout->dim_size[0] = (unsigned int)handle->desc.K;
layout->dim_size[1] = (unsigned int)handle->desc.K;
}
} else {
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) {
layout->format = handle->desc.buffer_format;
layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) || ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
layout->datatype = handle->desc.datatype_in;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 1;
if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) {
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
} else {
layout->dim_size[0] = (unsigned int)handle->desc.K;
}
} else { /* coverity[dead_error_begin] */
free(layout->dim_type);
free(layout->dim_size);
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
} else {
*status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
}
} else {
*status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return layout;
}
LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
{
  size_t size = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return size;
  }
  { /* every scratch buffer below is over-allocated by 64 bytes so it can later be aligned to a cacheline boundary */
    const size_t N = (size_t)handle->desc.N;
    const size_t C = (size_t)handle->desc.C;
    const size_t K = (size_t)handle->desc.K;
    const size_t T = (size_t)handle->desc.max_T;
    const size_t tin  = libxsmm_dnn_typesize(handle->desc.datatype_in);
    const size_t tout = libxsmm_dnn_typesize(handle->desc.datatype_out);
    /* dw/dr gradients are accumulated in FP32 when the cell runs in BF16 */
    const size_t tdwdr = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : tin;
    switch (handle->desc.cell_type) {
      case LIBXSMM_DNN_RNNCELL_RNN_RELU:
      case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
      case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            /* the forward pass of the simple RNN cells needs no scratch */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += C * K * tin      + 64;  /* wT */
            size += K * K * tin      + 64;  /* rT */
            size += C * N * tin      + 64;  /* xT */
            size += K * N * tout     + 64;  /* hT */
            size += K * N * tout * T + 64;  /* deltat */
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_LSTM: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            size += C * K * tin * 4 + 4 * 64;  /* w */
            size += K * K * tin * 4 + 4 * 64;  /* r */
            /* BF16 additionally needs FP32 staging buffers for intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              size += 7 * (K * N * sizeof(float) * T + 64);  /* intermediate scratches */
              size += K * N * sizeof(float) + 64;            /* intermediate scratches */
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += C * K * tdwdr * 4 + 4 * 64;  /* w */
            size += K * K * tdwdr * 4 + 4 * 64;  /* r */
            size += C * K * tin   * 4 + 4 * 64;  /* wT */
            size += K * K * tin   * 4 + 4 * 64;  /* rT */
            size += C * N * tin   + 64;          /* xT */
            size += K * N * tout  + 64;          /* hT */
            size += K * N * tdwdr + 64;          /* deltat */
            /* ten K x N gate/temporary buffers of output precision */
            size += 10 * (K * N * tout + 64);    /* di, df, do, dci, diB, dfB, dpB, dciB, t1, t2 */
            /* BF16 additionally needs FP32 staging buffers for intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              size += 4 * (K * sizeof(float) + 64);          /* intermediate db scratch */
              size += C * N * sizeof(float) * T + 64;        /* intermediate dx scratches */
              size += 7 * (K * N * sizeof(float) * T + 64);  /* intermediate scratches */
              size += 2 * (K * N * sizeof(float) + 64);      /* intermediate scratches */
            }
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_GRU: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            size += C * K * tin * 3 + 3 * 64;  /* w */
            size += K * K * tin * 3 + 3 * 64;  /* r */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += C * K * tdwdr * 3 + 3 * 64;  /* w */
            size += K * K * tdwdr * 3 + 3 * 64;  /* r */
            size += C * K * tin   * 3 + 3 * 64;  /* wT */
            size += K * K * tin   * 3 + 3 * 64;  /* rT */
            size += C * N * tin   + 64;          /* xT */
            size += K * N * tout  + 64;          /* hT */
            size += K * N * tdwdr + 64;          /* deltat */
            /* ten K x N gate/temporary buffers of output precision */
            size += 10 * (K * N * tout + 64);    /* di, dc, df, do, diB, dcB, dfB, oT, t1, t2 */
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
      }
    }
  }
  return size;
}
LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
{
  /* returns the base pointer of the bound scratch area, or NULL for an invalid handle */
  void* result = NULL;
  *status = LIBXSMM_DNN_SUCCESS;
  if (NULL == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    result = handle->scratch_base;
  }
  return result;
}
/** Round ADDRESS up to the next 64-byte boundary (identity when already aligned). */
static void* internal_rnncell_align64(uintptr_t address)
{
  const uintptr_t rem = address % 64;
  return (void*)(0 == rem ? address : (address + (64 - rem)));
}

/**
 * Carves the user-provided scratch buffer into the individual 64-byte aligned
 * sub-buffers required by the given cell type and compute kind, and stores the
 * resulting pointers in the handle. The per-slot "+ 64" slack mirrors the
 * worst-case alignment padding accounted for by
 * libxsmm_dnn_rnncell_get_scratch_size(), so the layout here must stay in sync
 * with the sizes reported there.
 *
 * @param handle  RNN cell handle (must be non-NULL).
 * @param kind    compute kind (FWD, BWD, UPD, BWDUPD, ALL).
 * @param scratch base pointer of the scratch buffer; must be non-NULL whenever
 *                the selected cell/kind combination needs scratch.
 * @return LIBXSMM_DNN_SUCCESS, or LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED /
 *         LIBXSMM_DNN_ERR_INVALID_KIND / LIBXSMM_DNN_ERR_INVALID_RNN_TYPE /
 *         LIBXSMM_DNN_ERR_INVALID_HANDLE on error.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (NULL != handle) {
    const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
    const size_t typesize_out = libxsmm_dnn_typesize(handle->desc.datatype_out);
    /* weight/recurrent-weight gradients are accumulated in FP32 for BF16 output */
    const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
    const size_t C = (size_t)handle->desc.C;
    const size_t K = (size_t)handle->desc.K;
    const size_t N = (size_t)handle->desc.N;
    const size_t T = (size_t)handle->desc.max_T;
    uintptr_t address = (uintptr_t)scratch;
    switch (handle->desc.cell_type) {
      case LIBXSMM_DNN_RNNCELL_RNN_RELU:
      case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
      case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            /* forward only has no scratch need */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* wT: transposed weights */
            handle->scratch_wT = internal_rnncell_align64(address);
            address += C * K * typesize_in + 64;
            /* rT: transposed recurrent weights */
            handle->scratch_rT = internal_rnncell_align64(address);
            address += K * K * typesize_in + 64;
            /* xT: transposed input */
            handle->scratch_xT = internal_rnncell_align64(address);
            address += C * N * typesize_in + 64;
            /* hT: transposed hidden state */
            handle->scratch_hT = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* deltat: delta buffer over all time steps */
            handle->scratch_deltat = internal_rnncell_align64(address);
            address += K * N * typesize_out * T + 64;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_LSTM: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w scratch: 4 gate weight blocks */
            handle->scratch_w = internal_rnncell_align64(address);
            address += C * K * typesize_in * 4 + 64;
            /* r scratch: 4 gate recurrent-weight blocks */
            handle->scratch_r = internal_rnncell_align64(address);
            address += K * K * typesize_in * 4 + 64;
            /* The scratches below are needed only for BF16 code for the intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              handle->cst_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->ht_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->it_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->ft_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->ot_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->cit_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->cot_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              /* csp: previous cell state, single time step */
              handle->csp_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) + 64;
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w/r gradients (FP32 accumulation when BF16 output) */
            handle->scratch_w = internal_rnncell_align64(address);
            address += C * K * dwdr_typesize * 4 + 64;
            handle->scratch_r = internal_rnncell_align64(address);
            address += K * K * dwdr_typesize * 4 + 64;
            /* transposed weights/recurrent weights */
            handle->scratch_wT = internal_rnncell_align64(address);
            address += C * K * typesize_in * 4 + 64;
            handle->scratch_rT = internal_rnncell_align64(address);
            address += K * K * typesize_in * 4 + 64;
            /* transposed input and hidden state */
            handle->scratch_xT = internal_rnncell_align64(address);
            address += C * N * typesize_in + 64;
            handle->scratch_hT = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* deltat */
            handle->scratch_deltat = internal_rnncell_align64(address);
            address += K * N * dwdr_typesize + 64;
            /* per-gate gradient buffers */
            handle->scratch_di = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_df = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_do = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_dci = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* per-gate "B" buffers */
            handle->scratch_diB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_dfB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_dpB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_dciB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* temporaries */
            handle->scratch_t1 = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_t2 = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* The scratches below are needed only for BF16 code for the intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              handle->scratch_dx = internal_rnncell_align64(address);
              address += C * N * sizeof(float) * T + 64;
              handle->scratch_dhp = internal_rnncell_align64(address);
              address += K * N * sizeof(float) + 64;
              /* db: bias gradient for the 4 gates */
              handle->scratch_db = internal_rnncell_align64(address);
              address += K * 4 * sizeof(float) + 64;
              handle->cst_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->ht_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->it_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->ft_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->ot_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->cit_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->cot_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) * T + 64;
              handle->csp_scratch = internal_rnncell_align64(address);
              address += K * N * sizeof(float) + 64;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_GRU: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w scratch: 3 gate weight blocks */
            handle->scratch_w = internal_rnncell_align64(address);
            address += C * K * typesize_in * 3 + 64;
            /* r scratch: 3 gate recurrent-weight blocks */
            handle->scratch_r = internal_rnncell_align64(address);
            address += K * K * typesize_in * 3 + 64;
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w/r gradients (FP32 accumulation when BF16 output) */
            handle->scratch_w = internal_rnncell_align64(address);
            address += C * K * dwdr_typesize * 3 + 64;
            handle->scratch_r = internal_rnncell_align64(address);
            address += K * K * dwdr_typesize * 3 + 64;
            /* transposed weights/recurrent weights */
            handle->scratch_wT = internal_rnncell_align64(address);
            address += C * K * typesize_in * 3 + 64;
            handle->scratch_rT = internal_rnncell_align64(address);
            address += K * K * typesize_in * 3 + 64;
            /* transposed input and hidden state */
            handle->scratch_xT = internal_rnncell_align64(address);
            address += C * N * typesize_in + 64;
            handle->scratch_hT = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* deltat */
            handle->scratch_deltat = internal_rnncell_align64(address);
            address += K * N * dwdr_typesize + 64;
            /* di */
            handle->scratch_di = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* dc (stored in the dci slot) */
            handle->scratch_dci = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* df */
            handle->scratch_df = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* do */
            handle->scratch_do = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* diB */
            handle->scratch_diB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* dcB (stored in the dciB slot) */
            handle->scratch_dciB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* dfB */
            handle->scratch_dfB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* doB (repurposed for oT) */
            handle->scratch_dpB = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            /* temporaries */
            handle->scratch_t1 = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
            handle->scratch_t2 = internal_rnncell_align64(address);
            address += K * N * typesize_out + 64;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
      }
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
{
  /* Clears the scratch sub-buffer pointers that bind_scratch set for the given
   * cell type and compute kind; the user-owned scratch memory itself is not freed. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (NULL == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return status;
  }
  switch (handle->desc.cell_type) {
    case LIBXSMM_DNN_RNNCELL_RNN_RELU:
    case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
    case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
          /* the forward pass binds no scratch, so nothing to release */
        } break;
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          handle->scratch_wT = NULL;
          handle->scratch_rT = NULL;
          handle->scratch_xT = NULL;
          handle->scratch_hT = NULL;
          handle->scratch_deltat = NULL;
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_LSTM: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
          handle->scratch_w = NULL;
          handle->scratch_r = NULL;
          /* BF16 intermediate buffers */
          handle->csp_scratch = NULL;
          handle->cst_scratch = NULL;
          handle->ht_scratch = NULL;
          handle->it_scratch = NULL;
          handle->ft_scratch = NULL;
          handle->ot_scratch = NULL;
          handle->cit_scratch = NULL;
          handle->cot_scratch = NULL;
        } break;
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          handle->scratch_w = NULL;
          handle->scratch_r = NULL;
          handle->scratch_wT = NULL;
          handle->scratch_rT = NULL;
          handle->scratch_xT = NULL;
          handle->scratch_hT = NULL;
          handle->scratch_deltat = NULL;
          /* per-gate gradient buffers */
          handle->scratch_di = NULL;
          handle->scratch_df = NULL;
          handle->scratch_do = NULL;
          handle->scratch_dci = NULL;
          handle->scratch_diB = NULL;
          handle->scratch_dfB = NULL;
          handle->scratch_dpB = NULL;
          handle->scratch_dciB = NULL;
          handle->scratch_t1 = NULL;
          handle->scratch_t2 = NULL;
          /* BF16 intermediate buffers */
          handle->csp_scratch = NULL;
          handle->cst_scratch = NULL;
          handle->ht_scratch = NULL;
          handle->it_scratch = NULL;
          handle->ft_scratch = NULL;
          handle->ot_scratch = NULL;
          handle->cit_scratch = NULL;
          handle->cot_scratch = NULL;
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_GRU: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
          handle->scratch_w = NULL;
          handle->scratch_r = NULL;
          handle->ht_scratch = NULL;
          handle->it_scratch = NULL;
          handle->cit_scratch = NULL;
          handle->ft_scratch = NULL;
          handle->ot_scratch = NULL;
        } break;
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          handle->scratch_w = NULL;
          handle->scratch_r = NULL;
          handle->scratch_wT = NULL;
          handle->scratch_rT = NULL;
          handle->scratch_xT = NULL;
          handle->scratch_hT = NULL;
          handle->scratch_deltat = NULL;
          handle->scratch_di = NULL;
          handle->scratch_dci = NULL;
          handle->scratch_df = NULL;
          handle->scratch_do = NULL;
          handle->scratch_diB = NULL;
          handle->scratch_dciB = NULL;
          handle->scratch_dfB = NULL;
          handle->scratch_dpB = NULL;
          handle->scratch_t1 = NULL;
          handle->scratch_t2 = NULL;
          handle->ht_scratch = NULL;
          handle->it_scratch = NULL;
          handle->ft_scratch = NULL;
          handle->ot_scratch = NULL;
          handle->cit_scratch = NULL;
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    default: {
      status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
    }
  }
  return status;
}
LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
{
  /* Reports the number of bytes of internal state required for the given cell
   * type and compute kind. Only the vanilla RNN cells need internal state (the
   * pre-activation buffer zt); LSTM/GRU expose their gates as i/o tensors. */
  size_t size = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (NULL == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return size;
  }
  switch (handle->desc.cell_type) {
    case LIBXSMM_DNN_RNNCELL_RNN_RELU:
    case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
    case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          /* zt: K x N per time step in FP32, plus 64 bytes of alignment slack */
          size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
        } break;
        default: {
          *status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_LSTM:
    case LIBXSMM_DNN_RNNCELL_GRU: {
      /* with the gate activations exposed as i/o, there is currently no need for internal state */
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
        default: {
          *status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    default: {
      *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
    }
  }
  return size;
}
LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
{
  /* Returns the internal-state pointer bound via bind_internalstate,
   * or NULL (with *status set) on a NULL handle. */
  void* result = NULL;
  *status = LIBXSMM_DNN_SUCCESS;
  if (NULL == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    result = handle->internal_z;
  }
  return result;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate)
{
  /* Binds the user-provided internal-state buffer. Only the vanilla RNN cells
   * use internal state (the zt buffer); LSTM/GRU accept the call as a no-op.
   * The pointer is rounded up to the next 64-byte boundary before being stored. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  const uintptr_t address = (uintptr_t)internalstate;
  if (NULL == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return status;
  }
  switch (handle->desc.cell_type) {
    case LIBXSMM_DNN_RNNCELL_RNN_RELU:
    case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
    case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
      if (internalstate == 0) {
        status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
        return status;
      }
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          /* align the stored pointer up to a 64-byte boundary */
          if (0 == (address % 64)) {
            handle->internal_z = (void*)address;
          } else {
            handle->internal_z = (void*)(address + (64 - (address % 64)));
          }
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_LSTM:
    case LIBXSMM_DNN_RNNCELL_GRU: {
      /* nothing to bind; only validate the compute kind */
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    default: {
      status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
    }
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
{
  /* Clears the internal-state pointer set by bind_internalstate; the
   * user-owned memory itself is not freed. LSTM/GRU hold no internal state. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (NULL == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return status;
  }
  switch (handle->desc.cell_type) {
    case LIBXSMM_DNN_RNNCELL_RNN_RELU:
    case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
    case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          handle->internal_z = NULL;
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_LSTM:
    case LIBXSMM_DNN_RNNCELL_GRU: {
      /* nothing bound; only validate the compute kind */
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    default: {
      status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
    }
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias)
{
  /* Records the LSTM forget-gate bias on the handle. Despite the name, no
   * memory is allocated here — the value is simply stored for later use. */
  if (NULL == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  handle->forget_bias = forget_bias;
  return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
{
  /* Attaches a user tensor to the handle slot identified by TYPE, after
   * verifying that the tensor's data layout matches the layout the handle
   * expects for that slot. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* reject tensor types that are not part of the RNN-cell interface */
  switch (type) {
    case LIBXSMM_DNN_RNN_REGULAR_INPUT:             case LIBXSMM_DNN_RNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:           case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV: case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV:
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:            case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:      case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:      case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS:
    case LIBXSMM_DNN_RNN_REGULAR_BIAS:              case LIBXSMM_DNN_RNN_GRADIENT_BIAS:
    case LIBXSMM_DNN_RNN_REGULAR_CS:                case LIBXSMM_DNN_RNN_GRADIENT_CS:
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:      case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:
    case LIBXSMM_DNN_RNN_INTERNAL_I:                case LIBXSMM_DNN_RNN_INTERNAL_F:
    case LIBXSMM_DNN_RNN_INTERNAL_O:                case LIBXSMM_DNN_RNN_INTERNAL_CI:
    case LIBXSMM_DNN_RNN_INTERNAL_CO:
      break;
    default: {
      status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      return status;
    }
  }
  if (NULL != handle && NULL != tensor) {
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_rnncell_create_tensor_datalayout(handle, type, &status);
    if (0 != libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status)) {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    } else {
      libxsmm_dnn_tensor* const t = (libxsmm_dnn_tensor*)tensor;
      switch (type) {
        case LIBXSMM_DNN_RNN_REGULAR_INPUT:              handle->xt   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_INPUT:             handle->dxt  = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:            handle->csp  = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:           handle->dcsp = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:  handle->hp   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV: handle->dhp  = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:             handle->w    = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:       handle->wt   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:            handle->dw   = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:       handle->r    = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS: handle->rt   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:      handle->dr   = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_BIAS:               handle->b    = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_BIAS:              handle->db   = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_CS:                 handle->cst  = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_CS:                handle->dcs  = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:       handle->ht   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:      handle->dht  = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_I:                 handle->it   = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_F:                 handle->ft   = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_O:                 handle->ot   = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_CI:                handle->cit  = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_CO:                handle->cot  = t; break;
        default: break; /* unreachable: filtered by the guard above */
      }
    }
    libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
/* Returns the tensor previously bound to this RNN cell for the given type,
 * or NULL when the type is not supported by RNN cells or the handle is NULL.
 * Fix: the status out-parameter used to be ignored (LIBXSMM_UNUSED with a
 * TODO), so callers could not distinguish failure causes; it is now populated.
 * A NULL status pointer is tolerated for backward compatibility, and the
 * error codes match the sibling release/get functions in this file. */
libxsmm_dnn_tensor* tensor = 0;
/* check for tensor type */
if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) &&
(type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) &&
(type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
(type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) &&
(type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) &&
(type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
(type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) &&
(type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) &&
(type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) &&
(type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) &&
(type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) &&
(type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
if (0 != status) *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
return tensor;
}
if (0 == handle) {
if (0 != status) *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
return tensor;
}
if (0 != status) *status = LIBXSMM_DNN_SUCCESS;
switch (type) {
case LIBXSMM_DNN_RNN_REGULAR_INPUT:              tensor = handle->xt;   break;
case LIBXSMM_DNN_RNN_GRADIENT_INPUT:             tensor = handle->dxt;  break;
case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:            tensor = handle->csp;  break;
case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:           tensor = handle->dcsp; break;
case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:  tensor = handle->hp;   break;
case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV: tensor = handle->dhp;  break;
case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:             tensor = handle->w;    break;
case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:       tensor = handle->wt;   break;
case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:            tensor = handle->dw;   break;
case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:       tensor = handle->r;    break;
case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS: tensor = handle->rt;   break;
case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:      tensor = handle->dr;   break;
case LIBXSMM_DNN_RNN_REGULAR_BIAS:               tensor = handle->b;    break;
case LIBXSMM_DNN_RNN_GRADIENT_BIAS:              tensor = handle->db;   break;
case LIBXSMM_DNN_RNN_REGULAR_CS:                 tensor = handle->cst;  break;
case LIBXSMM_DNN_RNN_GRADIENT_CS:                tensor = handle->dcs;  break;
case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:       tensor = handle->ht;   break;
case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:      tensor = handle->dht;  break;
case LIBXSMM_DNN_RNN_INTERNAL_I:                 tensor = handle->it;   break;
case LIBXSMM_DNN_RNN_INTERNAL_F:                 tensor = handle->ft;   break;
case LIBXSMM_DNN_RNN_INTERNAL_O:                 tensor = handle->ot;   break;
case LIBXSMM_DNN_RNN_INTERNAL_CI:                tensor = handle->cit;  break;
case LIBXSMM_DNN_RNN_INTERNAL_CO:                tensor = handle->cot;  break;
default: break; /* cannot happen: type was validated above */
}
return tensor;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type)
{
/* Detaches (NULLs) the tensor slot of the given type on the RNN-cell handle.
 * The tensor object itself is not destroyed; ownership stays with the caller. */
/* reject tensor types that RNN cells do not support */
if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) &&
(type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) &&
(type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
(type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) &&
(type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) &&
(type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
(type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) &&
(type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) &&
(type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) &&
(type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) &&
(type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) &&
(type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
}
if (0 == handle) {
return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
}
switch (type) {
case LIBXSMM_DNN_RNN_REGULAR_INPUT:              handle->xt   = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_INPUT:             handle->dxt  = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:            handle->csp  = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:           handle->dcsp = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:  handle->hp   = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV: handle->dhp  = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:             handle->w    = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:       handle->wt   = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:            handle->dw   = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:       handle->r    = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS: handle->rt   = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:      handle->dr   = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_BIAS:               handle->b    = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_BIAS:              handle->db   = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_CS:                 handle->cst  = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_CS:                handle->dcs  = 0; break;
case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:       handle->ht   = 0; break;
case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:      handle->dht  = 0; break;
case LIBXSMM_DNN_RNN_INTERNAL_I:                 handle->it   = 0; break;
case LIBXSMM_DNN_RNN_INTERNAL_F:                 handle->ft   = 0; break;
case LIBXSMM_DNN_RNN_INTERNAL_O:                 handle->ot   = 0; break;
case LIBXSMM_DNN_RNN_INTERNAL_CI:                handle->cit  = 0; break;
case LIBXSMM_DNN_RNN_INTERNAL_CO:                handle->cot  = 0; break;
default: break; /* cannot happen: type was validated above */
}
return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T ) {
/* Sets the active sequence length; it must not exceed the maximum length
 * (desc.max_T) the handle was created with. Guard-clause style. */
if (0 == handle) {
return LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
if (T > handle->desc.max_T) {
return LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN;
}
handle->T = T;
return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status ) {
/* Returns the active sequence length, or 0 with *status set to
 * LIBXSMM_DNN_ERR_INVALID_HANDLE when the handle is NULL.
 * NOTE(review): status is dereferenced unconditionally, as in the original. */
libxsmm_blasint result = 0;
if (0 != handle) {
*status = LIBXSMM_DNN_SUCCESS;
result = handle->T;
} else {
*status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
return result;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind,
/*unsigned*/int start_thread, /*unsigned*/int tid)
{
/* Single-threaded-entry dispatcher: picks the kernel that matches the
 * handle's buffer/filter layout for the requested compute kind.
 * BWD, UPD and BWDUPD all share the fused backward/weight-update kernels. */
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
if (0 == handle) {
return LIBXSMM_DNN_ERR_INVALID_HANDLE;
}
{
/* precompute the three supported layout combinations once */
const int is_nc_ck     = (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC)       && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK);
const int is_nc_kcck   = (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC)       && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED);
const int is_ncnc_kcck = (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED);
switch (kind) {
case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
if (is_nc_ck) {
status = libxsmm_dnn_rnncell_st_fwd_nc_ck( handle, start_thread, tid );
} else if (is_nc_kcck) {
status = libxsmm_dnn_rnncell_st_fwd_nc_kcck( handle, start_thread, tid );
} else if (is_ncnc_kcck) {
status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck( handle, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} break;
case LIBXSMM_DNN_COMPUTE_KIND_BWD:
case LIBXSMM_DNN_COMPUTE_KIND_UPD:
case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
if (is_nc_ck) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck( handle, kind, start_thread, tid );
} else if (is_nc_kcck) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck( handle, kind, start_thread, tid );
} else if (is_ncnc_kcck) {
status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck( handle, kind, start_thread, tid );
} else {
status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_KIND;
}
}
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Kunal Banerjee, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_rnncell_backward_weight_update.h"
#include "libxsmm_dnn_elementwise.h"
#include "libxsmm_main.h"
/* Transposes a 32x32 tile of 16-bit elements (row stride 32) from `in` to
 * `out` using AVX-512 unpack/shuffle/permute sequences; the structure matches
 * the classic AVX-512 16x16 transpose kernel, applied to the four 16x16
 * quadrants in two passes. On non-AVX512_CORE builds this is a no-op.
 * NOTE(review): `in` and `out` must each hold 32*32 shorts and must not
 * alias -- TODO confirm against callers (not visible here). */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void trans_act(short int *in, short int *out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
__m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
__m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
__m512i v0, v1, v2, v3, v4, v5, v6, v7;
/* cross-register 64-bit lane permutation used to pair upper/lower halves */
const __m512i idx_v = _mm512_set_epi64(13, 12, 7, 6, 9, 8, 3, 2);
/* blend masks: 204 = 0xCC (upper two 128-bit lanes), 51 = 0x33 (lower two) */
const __mmask8 mask0 = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(204);
const __mmask8 mask1 = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(51);
const int in_width = 32, out_width = 32;
/* ---- first pass: load input rows 0..15 (32 shorts = one zmm per row) and
 * interleave adjacent rows at 16-bit granularity ---- */
r0 = _mm512_loadu_si512(in + 0*in_width);
r1 = _mm512_loadu_si512(in + 1*in_width);
t0 = _mm512_unpacklo_epi16(r0,r1);
t1 = _mm512_unpackhi_epi16(r0,r1);
r2 = _mm512_loadu_si512(in + 2*in_width);
r3 = _mm512_loadu_si512(in + 3*in_width);
t2 = _mm512_unpacklo_epi16(r2,r3);
t3 = _mm512_unpackhi_epi16(r2,r3);
r4 = _mm512_loadu_si512(in + 4*in_width);
r5 = _mm512_loadu_si512(in + 5*in_width);
t4 = _mm512_unpacklo_epi16(r4,r5);
t5 = _mm512_unpackhi_epi16(r4,r5);
r6 = _mm512_loadu_si512(in + 6*in_width);
r7 = _mm512_loadu_si512(in + 7*in_width);
t6 = _mm512_unpacklo_epi16(r6,r7);
t7 = _mm512_unpackhi_epi16(r6,r7);
r8 = _mm512_loadu_si512(in + 8*in_width);
r9 = _mm512_loadu_si512(in + 9*in_width);
t8 = _mm512_unpacklo_epi16(r8,r9);
t9 = _mm512_unpackhi_epi16(r8,r9);
ra = _mm512_loadu_si512(in + 10*in_width);
rb = _mm512_loadu_si512(in + 11*in_width);
ta = _mm512_unpacklo_epi16(ra,rb);
tb = _mm512_unpackhi_epi16(ra,rb);
rc = _mm512_loadu_si512(in + 12*in_width);
rd = _mm512_loadu_si512(in + 13*in_width);
tc = _mm512_unpacklo_epi16(rc,rd);
td = _mm512_unpackhi_epi16(rc,rd);
re = _mm512_loadu_si512(in + 14*in_width);
rf = _mm512_loadu_si512(in + 15*in_width);
te = _mm512_unpacklo_epi16(re,rf);
tf = _mm512_unpackhi_epi16(re,rf);
/* widen the interleave: 32-bit pairs across row groups of 4 */
r0 = _mm512_unpacklo_epi32(t0,t2);
r1 = _mm512_unpackhi_epi32(t0,t2);
r2 = _mm512_unpacklo_epi32(t1,t3);
r3 = _mm512_unpackhi_epi32(t1,t3);
r4 = _mm512_unpacklo_epi32(t4,t6);
r5 = _mm512_unpackhi_epi32(t4,t6);
r6 = _mm512_unpacklo_epi32(t5,t7);
r7 = _mm512_unpackhi_epi32(t5,t7);
r8 = _mm512_unpacklo_epi32(t8,ta);
r9 = _mm512_unpackhi_epi32(t8,ta);
ra = _mm512_unpacklo_epi32(t9,tb);
rb = _mm512_unpackhi_epi32(t9,tb);
rc = _mm512_unpacklo_epi32(tc,te);
rd = _mm512_unpackhi_epi32(tc,te);
re = _mm512_unpacklo_epi32(td,tf);
rf = _mm512_unpackhi_epi32(td,tf);
/* widen again: 64-bit pairs across row groups of 8 */
t0 = _mm512_unpacklo_epi64(r0,r4);
t1 = _mm512_unpackhi_epi64(r0,r4);
t2 = _mm512_unpacklo_epi64(r1,r5);
t3 = _mm512_unpackhi_epi64(r1,r5);
t4 = _mm512_unpacklo_epi64(r2,r6);
t5 = _mm512_unpackhi_epi64(r2,r6);
t6 = _mm512_unpacklo_epi64(r3,r7);
t7 = _mm512_unpackhi_epi64(r3,r7);
t8 = _mm512_unpacklo_epi64(r8,rc);
t9 = _mm512_unpackhi_epi64(r8,rc);
ta = _mm512_unpacklo_epi64(r9,rd);
tb = _mm512_unpackhi_epi64(r9,rd);
tc = _mm512_unpacklo_epi64(ra,re);
td = _mm512_unpackhi_epi64(ra,re);
te = _mm512_unpacklo_epi64(rb,rf);
tf = _mm512_unpackhi_epi64(rb,rf);
/* shuffle 128-bit lanes: 0x88 keeps even lanes, 0xdd keeps odd lanes */
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
ra = _mm512_shuffle_i32x4(tc, td, 0x88);
rb = _mm512_shuffle_i32x4(te, tf, 0x88);
rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
re = _mm512_shuffle_i32x4(tc, td, 0xdd);
rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
/* final cross-half permute + masked blend, then store two 256-bit
 * half-rows per result register into output rows 0..31, columns 0..15 */
v0 = _mm512_permutex2var_epi64(r0, idx_v, r8);
t0 = _mm512_mask_blend_epi64( mask0, r0, v0);
_mm256_storeu_si256((__m256i*)(out + 0*out_width), _mm512_extracti64x4_epi64(t0, 0));
_mm256_storeu_si256((__m256i*)(out + 1*out_width), _mm512_extracti64x4_epi64(t0, 1));
t8 = _mm512_mask_blend_epi64( mask1, r8, v0);
_mm256_storeu_si256((__m256i*)(out + 16*out_width), _mm512_extracti64x4_epi64(t8, 0));
_mm256_storeu_si256((__m256i*)(out + 17*out_width), _mm512_extracti64x4_epi64(t8, 1));
v1 = _mm512_permutex2var_epi64(r1, idx_v, r9);
t1 = _mm512_mask_blend_epi64( mask0, r1, v1);
_mm256_storeu_si256((__m256i*)(out + 2*out_width), _mm512_extracti64x4_epi64(t1, 0));
_mm256_storeu_si256((__m256i*)(out + 3*out_width), _mm512_extracti64x4_epi64(t1, 1));
t9 = _mm512_mask_blend_epi64( mask1, r9, v1);
_mm256_storeu_si256((__m256i*)(out + 18*out_width), _mm512_extracti64x4_epi64(t9, 0));
_mm256_storeu_si256((__m256i*)(out + 19*out_width), _mm512_extracti64x4_epi64(t9, 1));
v2 = _mm512_permutex2var_epi64(r2, idx_v, ra);
t2 = _mm512_mask_blend_epi64( mask0, r2, v2);
_mm256_storeu_si256((__m256i*)(out + 4*out_width), _mm512_extracti64x4_epi64(t2, 0));
_mm256_storeu_si256((__m256i*)(out + 5*out_width), _mm512_extracti64x4_epi64(t2, 1));
ta = _mm512_mask_blend_epi64( mask1, ra, v2);
_mm256_storeu_si256((__m256i*)(out + 20*out_width), _mm512_extracti64x4_epi64(ta, 0));
_mm256_storeu_si256((__m256i*)(out + 21*out_width), _mm512_extracti64x4_epi64(ta, 1));
v3 = _mm512_permutex2var_epi64(r3, idx_v, rb);
t3 = _mm512_mask_blend_epi64( mask0, r3, v3);
_mm256_storeu_si256((__m256i*)(out + 6*out_width), _mm512_extracti64x4_epi64(t3, 0));
_mm256_storeu_si256((__m256i*)(out + 7*out_width), _mm512_extracti64x4_epi64(t3, 1));
tb = _mm512_mask_blend_epi64( mask1, rb, v3);
_mm256_storeu_si256((__m256i*)(out + 22*out_width), _mm512_extracti64x4_epi64(tb, 0));
_mm256_storeu_si256((__m256i*)(out + 23*out_width), _mm512_extracti64x4_epi64(tb, 1));
v4 = _mm512_permutex2var_epi64(r4, idx_v, rc);
t4 = _mm512_mask_blend_epi64( mask0, r4, v4);
_mm256_storeu_si256((__m256i*)(out + 8*out_width), _mm512_extracti64x4_epi64(t4, 0));
_mm256_storeu_si256((__m256i*)(out + 9*out_width), _mm512_extracti64x4_epi64(t4, 1));
tc = _mm512_mask_blend_epi64( mask1, rc, v4);
_mm256_storeu_si256((__m256i*)(out + 24*out_width), _mm512_extracti64x4_epi64(tc, 0));
_mm256_storeu_si256((__m256i*)(out + 25*out_width), _mm512_extracti64x4_epi64(tc, 1));
v5 = _mm512_permutex2var_epi64(r5, idx_v, rd);
t5 = _mm512_mask_blend_epi64( mask0, r5, v5);
_mm256_storeu_si256((__m256i*)(out + 10*out_width), _mm512_extracti64x4_epi64(t5, 0));
_mm256_storeu_si256((__m256i*)(out + 11*out_width), _mm512_extracti64x4_epi64(t5, 1));
td = _mm512_mask_blend_epi64( mask1, rd, v5);
_mm256_storeu_si256((__m256i*)(out + 26*out_width), _mm512_extracti64x4_epi64(td, 0));
_mm256_storeu_si256((__m256i*)(out + 27*out_width), _mm512_extracti64x4_epi64(td, 1));
v6 = _mm512_permutex2var_epi64(r6, idx_v, re);
t6 = _mm512_mask_blend_epi64( mask0, r6, v6);
_mm256_storeu_si256((__m256i*)(out + 12*out_width), _mm512_extracti64x4_epi64(t6, 0));
_mm256_storeu_si256((__m256i*)(out + 13*out_width), _mm512_extracti64x4_epi64(t6, 1));
te = _mm512_mask_blend_epi64( mask1, re, v6);
_mm256_storeu_si256((__m256i*)(out + 28*out_width), _mm512_extracti64x4_epi64(te, 0));
_mm256_storeu_si256((__m256i*)(out + 29*out_width), _mm512_extracti64x4_epi64(te, 1));
v7 = _mm512_permutex2var_epi64(r7, idx_v, rf);
t7 = _mm512_mask_blend_epi64( mask0, r7, v7);
_mm256_storeu_si256((__m256i*)(out + 14*out_width), _mm512_extracti64x4_epi64(t7, 0));
_mm256_storeu_si256((__m256i*)(out + 15*out_width), _mm512_extracti64x4_epi64(t7, 1));
tf = _mm512_mask_blend_epi64( mask1, rf, v7);
_mm256_storeu_si256((__m256i*)(out + 30*out_width), _mm512_extracti64x4_epi64(tf, 0));
_mm256_storeu_si256((__m256i*)(out + 31*out_width), _mm512_extracti64x4_epi64(tf, 1));
/* ---- second pass: identical pipeline for input rows 16..31
 * (offset 16*32), with results stored at column offset 16 (out + 16) ---- */
r0 = _mm512_loadu_si512(in + 16*32 + 0*in_width);
r1 = _mm512_loadu_si512(in + 16*32 + 1*in_width);
t0 = _mm512_unpacklo_epi16(r0,r1);
t1 = _mm512_unpackhi_epi16(r0,r1);
r2 = _mm512_loadu_si512(in + 16*32 + 2*in_width);
r3 = _mm512_loadu_si512(in + 16*32 + 3*in_width);
t2 = _mm512_unpacklo_epi16(r2,r3);
t3 = _mm512_unpackhi_epi16(r2,r3);
r4 = _mm512_loadu_si512(in + 16*32 + 4*in_width);
r5 = _mm512_loadu_si512(in + 16*32 + 5*in_width);
t4 = _mm512_unpacklo_epi16(r4,r5);
t5 = _mm512_unpackhi_epi16(r4,r5);
r6 = _mm512_loadu_si512(in + 16*32 + 6*in_width);
r7 = _mm512_loadu_si512(in + 16*32 + 7*in_width);
t6 = _mm512_unpacklo_epi16(r6,r7);
t7 = _mm512_unpackhi_epi16(r6,r7);
r8 = _mm512_loadu_si512(in + 16*32 + 8*in_width);
r9 = _mm512_loadu_si512(in + 16*32 + 9*in_width);
t8 = _mm512_unpacklo_epi16(r8,r9);
t9 = _mm512_unpackhi_epi16(r8,r9);
ra = _mm512_loadu_si512(in + 16*32 + 10*in_width);
rb = _mm512_loadu_si512(in + 16*32 + 11*in_width);
ta = _mm512_unpacklo_epi16(ra,rb);
tb = _mm512_unpackhi_epi16(ra,rb);
rc = _mm512_loadu_si512(in + 16*32 + 12*in_width);
rd = _mm512_loadu_si512(in + 16*32 + 13*in_width);
tc = _mm512_unpacklo_epi16(rc,rd);
td = _mm512_unpackhi_epi16(rc,rd);
re = _mm512_loadu_si512(in + 16*32 + 14*in_width);
rf = _mm512_loadu_si512(in + 16*32 + 15*in_width);
te = _mm512_unpacklo_epi16(re,rf);
tf = _mm512_unpackhi_epi16(re,rf);
r0 = _mm512_unpacklo_epi32(t0,t2);
r1 = _mm512_unpackhi_epi32(t0,t2);
r2 = _mm512_unpacklo_epi32(t1,t3);
r3 = _mm512_unpackhi_epi32(t1,t3);
r4 = _mm512_unpacklo_epi32(t4,t6);
r5 = _mm512_unpackhi_epi32(t4,t6);
r6 = _mm512_unpacklo_epi32(t5,t7);
r7 = _mm512_unpackhi_epi32(t5,t7);
r8 = _mm512_unpacklo_epi32(t8,ta);
r9 = _mm512_unpackhi_epi32(t8,ta);
ra = _mm512_unpacklo_epi32(t9,tb);
rb = _mm512_unpackhi_epi32(t9,tb);
rc = _mm512_unpacklo_epi32(tc,te);
rd = _mm512_unpackhi_epi32(tc,te);
re = _mm512_unpacklo_epi32(td,tf);
rf = _mm512_unpackhi_epi32(td,tf);
t0 = _mm512_unpacklo_epi64(r0,r4);
t1 = _mm512_unpackhi_epi64(r0,r4);
t2 = _mm512_unpacklo_epi64(r1,r5);
t3 = _mm512_unpackhi_epi64(r1,r5);
t4 = _mm512_unpacklo_epi64(r2,r6);
t5 = _mm512_unpackhi_epi64(r2,r6);
t6 = _mm512_unpacklo_epi64(r3,r7);
t7 = _mm512_unpackhi_epi64(r3,r7);
t8 = _mm512_unpacklo_epi64(r8,rc);
t9 = _mm512_unpackhi_epi64(r8,rc);
ta = _mm512_unpacklo_epi64(r9,rd);
tb = _mm512_unpackhi_epi64(r9,rd);
tc = _mm512_unpacklo_epi64(ra,re);
td = _mm512_unpackhi_epi64(ra,re);
te = _mm512_unpacklo_epi64(rb,rf);
tf = _mm512_unpackhi_epi64(rb,rf);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
ra = _mm512_shuffle_i32x4(tc, td, 0x88);
rb = _mm512_shuffle_i32x4(te, tf, 0x88);
rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
re = _mm512_shuffle_i32x4(tc, td, 0xdd);
rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
v0 = _mm512_permutex2var_epi64(r0, idx_v, r8);
t0 = _mm512_mask_blend_epi64( mask0, r0, v0);
_mm256_storeu_si256((__m256i*)(out + 16 + 0*out_width), _mm512_extracti64x4_epi64(t0, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 1*out_width), _mm512_extracti64x4_epi64(t0, 1));
t8 = _mm512_mask_blend_epi64( mask1, r8, v0);
_mm256_storeu_si256((__m256i*)(out + 16 + 16*out_width), _mm512_extracti64x4_epi64(t8, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 17*out_width), _mm512_extracti64x4_epi64(t8, 1));
v1 = _mm512_permutex2var_epi64(r1, idx_v, r9);
t1 = _mm512_mask_blend_epi64( mask0, r1, v1);
_mm256_storeu_si256((__m256i*)(out + 16 + 2*out_width), _mm512_extracti64x4_epi64(t1, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 3*out_width), _mm512_extracti64x4_epi64(t1, 1));
t9 = _mm512_mask_blend_epi64( mask1, r9, v1);
_mm256_storeu_si256((__m256i*)(out + 16 + 18*out_width), _mm512_extracti64x4_epi64(t9, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 19*out_width), _mm512_extracti64x4_epi64(t9, 1));
v2 = _mm512_permutex2var_epi64(r2, idx_v, ra);
t2 = _mm512_mask_blend_epi64( mask0, r2, v2);
_mm256_storeu_si256((__m256i*)(out + 16 + 4*out_width), _mm512_extracti64x4_epi64(t2, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 5*out_width), _mm512_extracti64x4_epi64(t2, 1));
ta = _mm512_mask_blend_epi64( mask1, ra, v2);
_mm256_storeu_si256((__m256i*)(out + 16 + 20*out_width), _mm512_extracti64x4_epi64(ta, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 21*out_width), _mm512_extracti64x4_epi64(ta, 1));
v3 = _mm512_permutex2var_epi64(r3, idx_v, rb);
t3 = _mm512_mask_blend_epi64( mask0, r3, v3);
_mm256_storeu_si256((__m256i*)(out + 16 + 6*out_width), _mm512_extracti64x4_epi64(t3, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 7*out_width), _mm512_extracti64x4_epi64(t3, 1));
tb = _mm512_mask_blend_epi64( mask1, rb, v3);
_mm256_storeu_si256((__m256i*)(out + 16 + 22*out_width), _mm512_extracti64x4_epi64(tb, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 23*out_width), _mm512_extracti64x4_epi64(tb, 1));
v4 = _mm512_permutex2var_epi64(r4, idx_v, rc);
t4 = _mm512_mask_blend_epi64( mask0, r4, v4);
_mm256_storeu_si256((__m256i*)(out + 16 + 8*out_width), _mm512_extracti64x4_epi64(t4, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 9*out_width), _mm512_extracti64x4_epi64(t4, 1));
tc = _mm512_mask_blend_epi64( mask1, rc, v4);
_mm256_storeu_si256((__m256i*)(out + 16 + 24*out_width), _mm512_extracti64x4_epi64(tc, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 25*out_width), _mm512_extracti64x4_epi64(tc, 1));
v5 = _mm512_permutex2var_epi64(r5, idx_v, rd);
t5 = _mm512_mask_blend_epi64( mask0, r5, v5);
_mm256_storeu_si256((__m256i*)(out + 16 + 10*out_width), _mm512_extracti64x4_epi64(t5, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 11*out_width), _mm512_extracti64x4_epi64(t5, 1));
td = _mm512_mask_blend_epi64( mask1, rd, v5);
_mm256_storeu_si256((__m256i*)(out + 16 + 26*out_width), _mm512_extracti64x4_epi64(td, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 27*out_width), _mm512_extracti64x4_epi64(td, 1));
v6 = _mm512_permutex2var_epi64(r6, idx_v, re);
t6 = _mm512_mask_blend_epi64( mask0, r6, v6);
_mm256_storeu_si256((__m256i*)(out + 16 + 12*out_width), _mm512_extracti64x4_epi64(t6, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 13*out_width), _mm512_extracti64x4_epi64(t6, 1));
te = _mm512_mask_blend_epi64( mask1, re, v6);
_mm256_storeu_si256((__m256i*)(out + 16 + 28*out_width), _mm512_extracti64x4_epi64(te, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 29*out_width), _mm512_extracti64x4_epi64(te, 1));
v7 = _mm512_permutex2var_epi64(r7, idx_v, rf);
t7 = _mm512_mask_blend_epi64( mask0, r7, v7);
_mm256_storeu_si256((__m256i*)(out + 16 + 14*out_width), _mm512_extracti64x4_epi64(t7, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 15*out_width), _mm512_extracti64x4_epi64(t7, 1));
tf = _mm512_mask_blend_epi64( mask1, rf, v7);
_mm256_storeu_si256((__m256i*)(out + 16 + 30*out_width), _mm512_extracti64x4_epi64(tf, 0));
_mm256_storeu_si256((__m256i*)(out + 16 + 31*out_width), _mm512_extracti64x4_epi64(tf, 1));
#else
/* no AVX512_CORE at compile time: silence unused-parameter warnings */
LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out);
#endif
}
/* Forward declarations of the per-precision / per-layout backward+weight-update
 * kernels defined below. "_emu" variants run BF16 via emulation on
 * AVX512_CORE (see the non-CPX fallbacks that delegate to them); the CPX/AMX
 * variants presumably use newer BF16 ISA extensions -- TODO confirm where the
 * "_amx" bodies not visible here are defined. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
/* F32 backward / weight-update for NC buffer + CK filter layout.
 * The actual kernel body is pulled in from cell-type-specific template files;
 * the LIBXSMM_DNN_RNN_*_BWDUPD defines select the activation inside the
 * shared RNN template. Requires AVX-512 at compile time, otherwise reports
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
#define LIBXSMM_RNN_CELL_AVX512
/* element types consumed by the included templates */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c"
} else {
/* should not happen */
}
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* BF16 backward / weight-update for NC buffer + CK filter layout, emulated
 * on AVX512_CORE (no native BF16 instructions). Only LSTM is implemented;
 * other cell types report LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* BF16 backward / weight-update for NC buffer + CK filter layout.
 * When the toolchain supports AVX512_CPX, the kernel is compiled with
 * LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI (native BF16 instructions); otherwise
 * the function simply forwards to the AVX512_CORE emulation variant. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* fallback: no CPX support at compile time, delegate to the emulated kernel */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
return libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(handle, kind, start_thread, tid);
}
#endif
/* BF16 backward / weight-update for NC buffer + CK filter layout, "_amx"
 * variant (CPX branch: compiled with native BF16 instructions and the AMX
 * template). The non-CPX fallback follows in the #else branch below.
 * Only LSTM is implemented; other cell types report NOT_IMPLEMENTED. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Fallback definition of the "_amx" NC/CK BF16 backward/weight-update entry
 * point for builds without AVX512_CPX intrinsics; compiled for AVX512_CORE.
 * Same cell-type dispatch as the CPX variant: only LSTM is implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#endif
/* BF16 backward/weight-update, NC/KCCK layout, AVX512_CORE emulation flavor
 * (no native BF16 instructions required). Only LSTM has a BF16
 * implementation; other cell types report LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the blocked NCNC/KCCK layout, "_amx"
 * variant built with AVX512_CPX intrinsics. Only LSTM is implemented; other
 * cell types report LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Fallback definition of the "_amx" NCNC/KCCK BF16 backward/weight-update
 * entry point for builds without AVX512_CPX intrinsics; compiled for
 * AVX512_CORE. Only LSTM is implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__ */
#define LIBXSMM_RNN_CELL_AVX512
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update, NC/KCCK layout, native-BF16 flavor built with
 * AVX512_CPX intrinsics. Only LSTM is implemented; other cell types report
 * LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Fallback for toolchains without AVX512_CPX intrinsics: the BF16 NC/KCCK
 * backward/weight-update entry point simply delegates to the AVX512_CORE
 * emulation implementation. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
const libxsmm_dnn_err_t result = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(handle, kind, start_thread, tid);
return result;
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update, NC/KCCK layout, "_amx" variant built with
 * AVX512_CPX intrinsics. Only LSTM is implemented; other cell types report
 * LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
/* Fallback definition of the "_amx" NC/KCCK BF16 backward/weight-update
 * entry point for builds without AVX512_CPX intrinsics; compiled for
 * AVX512_CORE. Only LSTM is implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#endif
/* F32 backward/weight-update, NC/KCCK layout, AVX512 build. All five cell
 * types are implemented here; the plain RNN cells share one template which is
 * specialized by the LIBXSMM_DNN_RNN_*_BWDUPD activation macro. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
#define LIBXSMM_RNN_CELL_AVX512
/* element typedefs consumed by the template includes below */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c"
} else {
/* should not happen */
}
#undef LIBXSMM_RNN_CELL_AVX512
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* F32 backward/weight-update entry point for the blocked NCNC/KCCK layout.
 * Currently a stub: the AVX512 branch only reports
 * LIBXSMM_DNN_ERR_NOT_IMPLEMENTED (the template include is parked under
 * "#if 0" for future enablement), and non-AVX512 builds report
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH.
 * Fix: the original marked the parameters unused twice (and the first
 * instance omitted "kind"); consolidated into a single complete statement. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
#if 0
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_ncnc_kcck_generic.tpl.c"
#endif
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid); LIBXSMM_UNUSED(kind);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* Public (internal-API) backward/weight-update dispatcher for the NC/CK
 * layout. On AVX512-capable targets it routes to the datatype- and
 * architecture-specific kernels above; otherwise it falls back to the
 * generic (non-AVX512) template implementations compiled inline here. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
#if 0
if (handle->? == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
#endif
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32( handle, kind, start_thread, tid );
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
/* BF16 kernels require an even minibatch size */
if ( handle->desc.N % 2 != 0 ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
/* pick emulation, native BF16 (CPX), or "_amx" (SPR) by runtime archid */
if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid );
} else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16( handle, kind, start_thread, tid );
} else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx( handle, kind, start_thread, tid );
}
#else
/* NOTE(review): unlike the nc_kcck dispatcher in this file, this non-CPX
 * build does not route >= AVX512_SPR targets to the _amx variant — all
 * CORE-and-above targets use the emulation path; confirm intentional. */
if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid );
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* generic (non-AVX512) fallback: F32 only, compiled from templates inline */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c"
} else {
/* should not happen */
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* Public (internal-API) backward/weight-update dispatcher for the NC/KCCK
 * layout. On AVX512-capable targets it routes to the datatype- and
 * architecture-specific kernels above (emulation / native BF16 / "_amx");
 * otherwise it falls back to the generic template implementations inline. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
#if 0
if (handle->? == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
#endif
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32( handle, kind, start_thread, tid );
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
/* BF16 kernels require an even minibatch size */
if ( handle->desc.N % 2 != 0 ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
/* pick emulation, native BF16 (CPX), or "_amx" (SPR) by runtime archid */
if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
} else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16( handle, kind, start_thread, tid );
} else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid );
}
#else
/* without CPX intrinsics: emulation up to SPR, "_amx" fallback beyond */
if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
} else if (libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid );
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
}
#endif
else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* generic (non-AVX512) fallback: F32 only, compiled from templates inline */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_BWDUPD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c"
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c"
} else {
/* should not happen */
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/* Public (internal-API) backward/weight-update dispatcher for the blocked
 * NCNC/KCCK layout. F32 routes to the (currently stubbed) f32 kernel; BF16 is
 * only dispatched for >= AVX512_SPR targets ("_amx" variant). The non-AVX512
 * fallback is not implemented.
 * NOTE(review): the #if (CPX) and #elif (AVX512) branches contain identical
 * dispatch code — only the compile-time guard differs. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* check if we have input, output and filter */
#if 0
if (handle->? == 0 ) {
status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
return status;
}
#endif
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid );
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#elif defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid );
} else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid);
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
} else
#endif
{
/* non-AVX512 fallback: no generic NCNC/KCCK implementation exists */
if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
return status;
}
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H
#define LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H
#include <libxsmm_dnn.h>
#include <libxsmm_dnn_rnncell.h>
/* Layout-specific entry points for the RNN-cell backward/weight-update pass
 * (NC/CK, NC/KCCK, and blocked NCNC/KCCK tensor layouts); each dispatches on
 * datatype and target architecture internally. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
#endif /* LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Kunal Banerjee (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_rnncell_forward.h"
#include "libxsmm_dnn_elementwise.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
/* F32 forward pass, NC/CK layout, AVX512 build. All five cell types are
 * implemented; the plain RNN cells share one template specialized by the
 * LIBXSMM_DNN_RNN_*_FWD activation macro. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* element typedefs consumed by the template includes below */
typedef float element_input_type;
typedef float element_output_type;
typedef float element_filter_type;
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_FWD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_FWD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_FWD
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
} else {
/* should not happen */
}
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
/* BF16 forward pass, NC/CK layout, AVX512_CORE emulation flavor (no native
 * BF16 instructions required). Only LSTM has a BF16 implementation; other
 * cell types report LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__*/
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 forward pass, NC/CK layout, native-BF16 flavor built with AVX512_CPX
 * intrinsics. Only LSTM is implemented; other cell types report
 * LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/
/* element typedefs consumed by the template include below */
typedef libxsmm_bfloat16 element_input_type;
typedef libxsmm_bfloat16 element_output_type;
typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
/* portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
} else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
} else {
/* should not happen */
}
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
return libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
/* BF16 forward pass for RNN cells, NC activation / CK weight layout, using the
 * AMX-flavored LSTM template.  Only LSTM is implemented; other cell types return
 * LIBXSMM_DNN_ERR_NOT_IMPLEMENTED.  Two alternative definitions are provided:
 * a CPX build (native BF16 conversion instructions) and, when CPX cannot be
 * targeted, a full AVX512-CORE build of the same template without the native
 * BF16 conversions. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* select the native CPX BF16 conversion instructions inside the helper macros */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no CPX support in the toolchain: build the same template for AVX512-CORE,
 * relying on emulated BF16 conversion (LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI unset) */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__ */
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* FP32 forward pass for RNN cells with blocked NCNC activation / KCCK weight
 * layout on AVX512.  The vanilla RNN activations (ReLU/sigmoid/tanh) share one
 * template, selected via an activation-specific macro; LSTM and GRU are not
 * implemented for this layout and return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* FP32 forward pass for RNN cells with NC activation / blocked KCCK weight
 * layout on AVX512.  ReLU/sigmoid/tanh RNN cells share one template selected by
 * an activation macro; LSTM uses its AVX512 template and GRU its own template. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included templates */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
  } else {
    /* should not happen */
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 forward pass (emulated BF16 conversions, AVX512-CORE) for RNN cells with
 * NC activation / blocked KCCK weight layout.  Only LSTM is implemented; other
 * cell types return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 forward pass for RNN cells with NC activation / blocked KCCK weight
 * layout, compiled for native AVX512-BF16 (CPX).  Only LSTM is implemented.
 * If the toolchain cannot target CPX, the #else branch forwards to the
 * AVX512-CORE emulation variant (..._emu). */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* select the native CPX BF16 conversion instructions inside the helper macros */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* toolchain cannot emit AVX512-CPX: delegate to the AVX512-CORE emulation path */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  return libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
/* BF16 forward pass for RNN cells with blocked NCNC activation / KCCK weight
 * layout using the AMX-flavored LSTM template.  Only LSTM is implemented.
 * Two alternative definitions: a CPX build with native BF16 conversions, and an
 * AVX512-CORE build of the same template when CPX cannot be targeted. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* select the native CPX BF16 conversion instructions inside the helper macros */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no CPX support in the toolchain: build the same template for AVX512-CORE,
 * relying on emulated BF16 conversion (LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI unset) */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* BF16 forward pass for RNN cells with NC activation / blocked KCCK weight
 * layout using the AMX-flavored LSTM template.  Only LSTM is implemented.
 * Two alternative definitions: a CPX build with native BF16 conversions, and an
 * AVX512-CORE build of the same template when CPX cannot be targeted. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* select the native CPX BF16 conversion instructions inside the helper macros */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no CPX support in the toolchain: build the same template for AVX512-CORE,
 * relying on emulated BF16 conversion (LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI unset) */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__ */
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* Dispatcher for the RNN-cell forward pass with NC activation / CK weight layout.
 * Selects an implementation by runtime target architecture (libxsmm_target_archid)
 * and by the handle's in/out datatypes: F32 goes to the AVX512 F32 kernel; BF16 is
 * routed to the emulation (CORE), native-CPX, or AMX variant depending on the arch
 * range.  When no AVX512 target is available at compile or run time, a generic
 * (non-vectorized template) F32 path is used; BF16 is unsupported there. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter (placeholder, currently disabled) */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32( handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    /* BF16 routing: [CORE, CPX) -> emulation, [CPX, SPR) -> native BF16, [SPR, ...) -> AMX */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx( handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    /* CORE-only toolchain: same arch-range routing; the bf16_bf16 entry point
     * is itself compiled as a forward to the emulation in this configuration */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx( handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic fallback: F32 only, templates expanded without AVX512 specialization */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
      } else {
        /* should not happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Dispatcher for the RNN-cell forward pass with blocked NCNC activation / KCCK
 * weight layout.  On AVX512 targets: F32 goes to the AVX512 F32 kernel; BF16 is
 * supported only from SPR on (AMX variant).  The CPX and plain-AVX512 compile-time
 * branches contain the same runtime dispatch; they differ only in which helper
 * variants were compiled.  The generic fallback supports F32 ReLU/sigmoid/tanh
 * only; LSTM and GRU are not implemented without AVX512. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter (placeholder, currently disabled) */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid);
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#elif defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx( handle, start_thread, tid);
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic fallback: F32 ReLU/sigmoid/tanh only */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
        status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
        status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
      } else {
        /* should not happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Dispatcher for the RNN-cell forward pass with NC activation / blocked KCCK
 * weight layout.  F32 goes to the AVX512 F32 kernel; BF16 is routed to the
 * emulation (CORE), native-CPX, or AMX variant depending on the runtime arch.
 * In a CORE-only build the CORE..SPR range is merged onto the emulation path.
 * Without AVX512, a generic F32 template path handles all five cell types. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter (placeholder, currently disabled) */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32( handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    /* BF16 routing: [CORE, CPX) -> emulation, [CPX, SPR) -> native BF16, [SPR, ...) -> AMX */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx( handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    /* CORE-only toolchain: CPX cannot be used, so the whole [CORE, SPR) range
     * runs the emulation; SPR and newer still get the AMX variant */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx( handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic fallback: F32 only, templates expanded without AVX512 specialization */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_FWD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
      } else {
        /* should not happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_RNNCELL_FORWARD_H
#define LIBXSMM_DNN_RNNCELL_FORWARD_H
#include <libxsmm_dnn.h>
#include <libxsmm_dnn_rnncell.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_RNNCELL_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_softmaxloss_backward.h"
#include "libxsmm_dnn_softmaxloss_forward.h"
#include "libxsmm_main.h"
/* Allocate and initialize a softmaxloss handle from the given descriptor.
 * Supported datatypes are F32 and BF16; the buffer format must be either
 * LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM (blocking derived from the feature map)
 * or LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED (blocking taken from the descriptor).
 * On failure, *status is set and NULL is returned (any partial allocation is
 * released).  On success the handle owns a thread barrier and carries the
 * scratch size required by the BF16 path. */
LIBXSMM_API libxsmm_dnn_softmaxloss* libxsmm_dnn_create_softmaxloss(libxsmm_dnn_softmaxloss_desc softmaxloss_desc, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_softmaxloss* handle = 0;
  int lpb;
  /* make sure the library is initialized */
  LIBXSMM_INIT
  /* reject any datatype other than F32 or BF16 up front */
  if ( (softmaxloss_desc.datatype != LIBXSMM_DNN_DATATYPE_F32) && (softmaxloss_desc.datatype != LIBXSMM_DNN_DATATYPE_BF16) ) {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    return 0;
  }
  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  handle = (libxsmm_dnn_softmaxloss*)calloc(1, sizeof(libxsmm_dnn_softmaxloss));
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    return 0;
  }
  *status = LIBXSMM_DNN_SUCCESS;
  /* keep a persistent copy of the description */
  handle->desc = softmaxloss_desc;
  if ( (handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
    int bk;
    /* derive the channel blocking from the feature-map geometry */
    *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C,
                                                  &(handle->bc), &bk, &lpb,
                                                  handle->desc.datatype, handle->desc.datatype );
    /* outer blocks: channels are blocked, the batch is not */
    handle->Bc = handle->desc.C / handle->bc;
    handle->bn = 1;
    handle->Bn = handle->desc.N;
  } else if ( (handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0 ) {
    /* packed format: blocking factors come directly from the descriptor */
    handle->bc = handle->desc.bc;
    handle->bn = handle->desc.bn;
    handle->Bc = handle->desc.C / handle->bc;
    handle->Bn = handle->desc.N / handle->bn;
  } else {
    /* unknown buffer format: release the handle and fail */
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    free( handle );
    return 0;
  }
  /* create barrier */
  handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
  /* BF16 needs FP32 scratch copies of the activations; F32 needs none */
  handle->scratch_size = ( softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 )
    ? (sizeof(float)*handle->desc.C*handle->desc.N*2) : 1;
  return handle;
}
/* Destroys a softmax-loss handle: releases the thread barrier and the handle
 * memory itself. Passing NULL yields LIBXSMM_DNN_ERR_INVALID_HANDLE. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_softmaxloss(const libxsmm_dnn_softmaxloss* handle) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* release the barrier before freeing the owning structure */
    if (0 != handle->barrier) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    free((libxsmm_dnn_softmaxloss*)handle); /* cast removes constness */
  }
  return result;
}
/* Builds the tensor data layout expected by a softmax-loss handle for the
 * given tensor role: activations (input/output) in either the LIBXSMM blocked
 * 3D format [N][Bc][bc] or the NCPACKED 4D format [Bn][Bc][bn][bc], and the
 * label tensor as a 1D I32 vector of length N. Returns NULL and sets *status
 * on invalid handle, unknown type, unsupported format, or allocation failure.
 *
 * BUGFIX: on partial allocation failure the original code freed only the
 * layout struct and leaked whichever of dim_type/dim_size had been allocated;
 * all error paths now free both arrays (free(NULL) is a no-op). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_softmaxloss_create_tensor_datalayout(const libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      layout->format = handle->desc.buffer_format;
      if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          /* blocked format: 3 dims, fastest-varying first: [bc][Bc][N] */
          layout->datatype = handle->desc.datatype;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 3;
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->bc;
            layout->dim_size[1] = handle->Bc;
            layout->dim_size[2] = handle->desc.N;
          } else {
            free(layout->dim_type); /* one of the two may have been allocated */
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
          /* NC-packed format: 4 dims, fastest-varying first: [bc][bn][Bc][Bn] */
          layout->datatype = handle->desc.datatype;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 4;
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->bc;
            layout->dim_size[1] = handle->bn;
            layout->dim_size[2] = handle->Bc;
            layout->dim_size[3] = handle->Bn;
          } else {
            free(layout->dim_type); /* one of the two may have been allocated */
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( type == LIBXSMM_DNN_LABEL ) {
        /* labels: one I32 class index per image */
        layout->datatype = LIBXSMM_DNN_DATATYPE_I32;
        layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
        layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
        if (0 != layout->dim_type && 0 != layout->dim_size) {
          layout->num_dims = 1;
          layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
          layout->dim_size[0] = handle->desc.N;
        } else {
          free(layout->dim_type); /* one of the two may have been allocated */
          free(layout->dim_size);
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
/* Reports how many bytes of scratch the user must provide for this handle. */
LIBXSMM_API size_t libxsmm_dnn_softmaxloss_get_scratch_size(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status) {
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    /* pad by 64 bytes so the buffer can be aligned internally even when the
     * caller allocates without any alignment guarantee */
    result = handle->scratch_size + 64;
  }
  return result;
}
/* Returns the scratch pointer currently bound to the handle (NULL if none). */
LIBXSMM_API void* libxsmm_dnn_softmaxloss_get_scratch_ptr(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status)
{
  void* result = 0;
  if (0 != handle) {
    *status = LIBXSMM_DNN_SUCCESS;
    result = handle->scratch;
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return result;
}
/* Binds user-provided scratch memory to the handle; the stored pointer is
 * rounded up to the next 64-byte boundary (the extra 64 bytes reported by
 * libxsmm_dnn_softmaxloss_get_scratch_size cover this). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_scratch(libxsmm_dnn_softmaxloss* handle, const void* scratch) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == scratch) {
    result = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  } else if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    const uintptr_t address = (uintptr_t)scratch;
    const uintptr_t misalignment = address % 64;
    /* align the internal scratch buffer if needed */
    handle->scratch = (0 == misalignment
      ? (void*)address
      : (void*)(address + (64 - misalignment)));
  }
  return result;
}
/* Detaches the scratch buffer from the handle; the memory itself is owned
 * (and must be freed) by the caller. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_scratch(libxsmm_dnn_softmaxloss* handle) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    handle->scratch = 0;
  }
  return result;
}
/* Binds a tensor to one of the handle's roles (regular/gradient input,
 * regular output, label) after verifying its layout matches the layout the
 * handle expects for that role. The handle does not take ownership. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  /* only these four tensor roles exist for softmax-loss */
  if ( type != LIBXSMM_DNN_REGULAR_INPUT && type != LIBXSMM_DNN_GRADIENT_INPUT &&
       type != LIBXSMM_DNN_REGULAR_OUTPUT && type != LIBXSMM_DNN_LABEL ) {
    return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle || 0 == tensor) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  {
    libxsmm_dnn_tensor_datalayout* expected = libxsmm_dnn_softmaxloss_create_tensor_datalayout(handle, type, &result);
    if (0 == libxsmm_dnn_compare_tensor_datalayout(expected, tensor->layout, &result)) {
      libxsmm_dnn_tensor *const bound = (libxsmm_dnn_tensor*)tensor;
      switch (type) {
        case LIBXSMM_DNN_REGULAR_INPUT:  handle->reg_input  = bound; break;
        case LIBXSMM_DNN_GRADIENT_INPUT: handle->grad_input = bound; break;
        case LIBXSMM_DNN_REGULAR_OUTPUT: handle->reg_output = bound; break;
        default:                         handle->label      = bound; break; /* LIBXSMM_DNN_LABEL */
      }
    } else {
      result = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout(expected);
  }
  return result;
}
/* Returns the tensor currently bound to the given role, or NULL (with
 * *status set) for an unknown role or invalid handle. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_softmaxloss_get_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if ( type != LIBXSMM_DNN_REGULAR_INPUT && type != LIBXSMM_DNN_GRADIENT_INPUT &&
       type != LIBXSMM_DNN_REGULAR_OUTPUT && type != LIBXSMM_DNN_LABEL ) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  } else if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:  result = handle->reg_input;  break;
      case LIBXSMM_DNN_GRADIENT_INPUT: result = handle->grad_input; break;
      case LIBXSMM_DNN_REGULAR_OUTPUT: result = handle->reg_output; break;
      default:                         result = handle->label;      break; /* LIBXSMM_DNN_LABEL */
    }
  }
  return result;
}
/* Unbinds the tensor from the given role; the tensor itself stays owned by
 * the caller and is not destroyed. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if ( type != LIBXSMM_DNN_REGULAR_INPUT && type != LIBXSMM_DNN_GRADIENT_INPUT &&
       type != LIBXSMM_DNN_REGULAR_OUTPUT && type != LIBXSMM_DNN_LABEL ) {
    result = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  } else if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:  handle->reg_input  = 0; break;
      case LIBXSMM_DNN_GRADIENT_INPUT: handle->grad_input = 0; break;
      case LIBXSMM_DNN_REGULAR_OUTPUT: handle->reg_output = 0; break;
      default:                         handle->label      = 0; break; /* LIBXSMM_DNN_LABEL */
    }
  }
  return result;
}
/* Single-threaded execution entry point: dispatches to the forward or
 * backward NCNC kernel depending on "kind". Called per thread ("tid"). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_execute_st(libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  libxsmm_dnn_err_t result;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else if (LIBXSMM_DNN_COMPUTE_KIND_FWD == kind) {
    result = libxsmm_dnn_softmaxloss_st_fwd_ncnc(handle, start_thread, tid);
  } else if (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) {
    result = libxsmm_dnn_softmaxloss_st_bwd_ncnc(handle, start_thread, tid);
  } else {
    result = LIBXSMM_DNN_ERR_INVALID_KIND;
  }
  return result;
}
/* Returns the loss value stored on the handle (0.0f on invalid handle). */
LIBXSMM_API float libxsmm_dnn_softmaxloss_get_loss(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status) {
  float result = 0.0f;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    *status = LIBXSMM_DNN_SUCCESS;
    result = handle->loss;
  }
  return result;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_softmaxloss_backward.h"
#include "libxsmm_main.h"
/* Forward declarations of the precision-specific backward kernels defined below. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);

/* FP32 backward softmax-loss kernel; body comes from the generic template,
 * expanded with float element types. Compiled for AVX512 targets; on
 * compilers without AVX512 intrinsics support it only reports an error. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* the included template expands in terms of these typedefs */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 backward softmax-loss kernel; same generic template as the FP32
 * variant, expanded with bfloat16 element types and the AVX512 BF16 path
 * enabled via LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Backward softmax-loss dispatcher: verifies the required tensors are bound
 * (gradient input, regular output, label), then selects the ISA- and
 * datatype-specific kernel. On pre-AVX512 targets (or builds without AVX512
 * intrinsics) the generic template is expanded inline instead. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and mask */
  if ( handle->grad_input == 0 || handle->reg_output == 0 || handle->label == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32( handle, start_thread, tid);
    } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16( handle, start_thread, tid);
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic (non-AVX512) fallback: expand the template right here */
    if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
    } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
      typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H
#define LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H

#include <libxsmm_dnn_softmaxloss.h>

/* Single-threaded backward softmax-loss entry point (NCNC-packed layout);
 * dispatches internally to a precision/ISA-specific kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);

#endif /* LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_softmaxloss_forward.h"
#include "libxsmm_main.h"
/* Forward declarations of the precision-specific forward kernels defined below. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);

/* FP32 forward softmax-loss kernel; body comes from the generic template,
 * expanded with float element types. Compiled for AVX512 targets; on
 * compilers without AVX512 intrinsics support it only reports an error. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* the included template expands in terms of these typedefs */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 forward softmax-loss kernel; same generic template as the FP32
 * variant, expanded with bfloat16 element types and the AVX512 BF16 path
 * enabled via LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Forward softmax-loss dispatcher: verifies the required tensors are bound
 * (regular input, regular output, label), then selects the ISA- and
 * datatype-specific kernel. On pre-AVX512 targets (or builds without AVX512
 * intrinsics) the generic template is expanded inline instead. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and mask */
  if ( handle->reg_input == 0 || handle->reg_output == 0 || handle->label == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32( handle, start_thread, tid);
    } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16( handle, start_thread, tid);
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic (non-AVX512) fallback: expand the template right here */
    if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
    } else if ( handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16 ) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
      typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H
#define LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H

#include <libxsmm_dnn_softmaxloss.h>

/* Single-threaded forward softmax-loss entry point (NCNC-packed layout);
 * dispatches internally to a precision/ISA-specific kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);

#endif /* LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H */
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include <libxsmm.h>
#include "libxsmm_main.h"
#include "libxsmm_dnn_tensor.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/* Convenience wrapper around libxsmm_dnn_link_qtensor with a zero
 * quantization scaling factor (i.e. a plain, non-quantized tensor). */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status)
{
  return libxsmm_dnn_link_qtensor(layout, data, 0, status);
}
/* Wraps user-owned memory "data" in a tensor described by "layout".
 * The layout is deep-copied into the tensor; the data buffer is NOT copied
 * and ownership stays with the caller. "scf" is the quantization scaling
 * factor stored alongside. Returns NULL and sets *status on any failure
 * (NULL arguments, allocation failure, or layout duplication failure). */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char scf, libxsmm_dnn_err_t* status)
{
  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  libxsmm_dnn_tensor* tensor = (libxsmm_dnn_tensor*)calloc(1, sizeof(libxsmm_dnn_tensor));
  *status = LIBXSMM_DNN_SUCCESS;
  if (layout != 0 && tensor != 0 && data != 0) {
    tensor->layout = libxsmm_dnn_duplicate_tensor_datalayout(layout, status);
    tensor->data = (void*)data; /* cast drops constness; the buffer remains caller-owned */
    tensor->scf = scf;
    /* when layout copy failed, free layout */
    if (*status != LIBXSMM_DNN_SUCCESS) {
      libxsmm_dnn_destroy_tensor_datalayout(tensor->layout);
    }
  } else {
    *status = LIBXSMM_DNN_ERR_CREATE_TENSOR;
  }
  /* on any failure release the (possibly allocated) tensor and return NULL */
  if (*status != LIBXSMM_DNN_SUCCESS) {
    free((libxsmm_dnn_tensor*)tensor);
    tensor = 0;
  }
  return tensor;
}
/* Deep-copies a tensor data layout (metadata plus the dim_type/dim_size
 * arrays). Returns the new layout, or NULL with *status set on invalid
 * input or allocation failure. The caller owns the result and releases it
 * via libxsmm_dnn_destroy_tensor_datalayout.
 *
 * BUGFIX: on partial allocation failure the original code leaked whichever
 * of dim_type/dim_size had been allocated and still returned the half-built
 * layout (with an error status); it now frees everything and returns NULL,
 * so callers never observe a layout with missing dimension arrays. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* dst_layout;
  *status = LIBXSMM_DNN_SUCCESS;
  dst_layout = 0;
  if (layout != 0 && layout->num_dims != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    dst_layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (0 != dst_layout) {
      dst_layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(layout->num_dims * sizeof(libxsmm_dnn_tensor_dimtype));
      dst_layout->dim_size = (unsigned int*)malloc(layout->num_dims * sizeof(unsigned int));
      dst_layout->num_dims = layout->num_dims;
      dst_layout->format = layout->format;
      dst_layout->datatype = layout->datatype;
      dst_layout->tensor_type = layout->tensor_type;
      if (0 != dst_layout->dim_type && 0 != dst_layout->dim_size) {
        unsigned int dim;
        for (dim = 0; dim < layout->num_dims; ++dim) {
          dst_layout->dim_type[dim] = layout->dim_type[dim];
          dst_layout->dim_size[dim] = layout->dim_size[dim];
        }
      } else {
        /* free both arrays (one may have been allocated; free(NULL) is a no-op)
         * and the layout itself, so no half-initialized layout escapes */
        free(dst_layout->dim_type);
        free(dst_layout->dim_size);
        free(dst_layout);
        dst_layout = 0;
        *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  }
  return dst_layout;
}
/* Compares two layouts: returns 0 when they match (num_dims, format,
 * datatype, and every per-dimension type/size), 1 when they differ, and the
 * sentinel 100 (with *status set) when either argument is NULL. */
LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status) {
  unsigned int diff = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout_a || 0 == layout_b) {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
    diff = 100; /* distinct value signalling invalid input rather than mismatch */
  } else if ( layout_a->num_dims != layout_b->num_dims
           || layout_a->format   != layout_b->format
           || layout_a->datatype != layout_b->datatype ) {
    diff = 1;
  } else {
    unsigned int dim;
    for (dim = 0; dim < layout_a->num_dims; ++dim) {
      if ( layout_a->dim_type[dim] != layout_b->dim_type[dim]
        || layout_a->dim_size[dim] != layout_b->dim_size[dim] ) {
        diff = 1;
      }
    }
  }
  return diff;
}
/* Releases a layout created by the duplicate/create functions, including its
 * dimension arrays. NULL is rejected with LIBXSMM_DNN_ERR_INVALID_LAYOUT. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout) {
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == layout) {
    result = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  } else {
    /* arrays first, then the owning struct */
    free(layout->dim_type);
    free(layout->dim_size);
    free(layout);
  }
  return result;
}
/* Returns the tensor size in bytes: element size times the product of all
 * dimension extents. Returns 0 (and sets *status) for a NULL layout.
 * NOTE(review): the product is computed in unsigned int and may wrap for
 * very large tensors — unchanged from the original contract. */
LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
  unsigned int nbytes = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout) {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  } else {
    unsigned int dim;
    nbytes = (unsigned int)libxsmm_dnn_typesize(layout->datatype);
    for (dim = 0; dim < layout->num_dims; ++dim) {
      nbytes *= layout->dim_size[dim];
    }
  }
  return nbytes;
}
/* Returns the number of elements (product of all dimension extents), or 0
 * with *status set when the layout is NULL. */
LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
  unsigned int count = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout) {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  } else {
    unsigned int dim;
    count = 1;
    for (dim = 0; dim < layout->num_dims; ++dim) {
      count *= layout->dim_size[dim];
    }
  }
  return count;
}
/* Points the tensor at a new user-owned data buffer; the tensor must carry a
 * valid, non-degenerate layout. The buffer is not copied or owned. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data) {
  if (0 == tensor || 0 == data) {
    return LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  if (0 == tensor->layout || 0 == tensor->layout->num_dims) {
    return LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  }
  tensor->data = (void*)data; /* cast drops constness; caller keeps ownership */
  return LIBXSMM_DNN_SUCCESS;
}
/* Returns the raw data pointer held by the tensor (NULL on invalid tensor). */
LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
{
  void* result = 0;
  if (0 != tensor) {
    *status = LIBXSMM_DNN_SUCCESS;
    result = tensor->data;
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return result;
}
/* Returns a deep copy of the tensor's layout (caller owns and must destroy
 * it), or NULL with *status set on invalid tensor or copy failure. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* copy = NULL;
  if (0 == tensor) {
    *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  } else {
    *status = LIBXSMM_DNN_SUCCESS;
    copy = libxsmm_dnn_duplicate_tensor_datalayout(tensor->layout, status);
  }
  return copy;
}
/* Returns the quantization scaling factor stored on the tensor (0 on
 * invalid tensor, with *status set). */
LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
{
  unsigned char result = 0;
  if (0 != tensor) {
    *status = LIBXSMM_DNN_SUCCESS;
    result = tensor->scf;
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return result;
}
/* Stores a new quantization scaling factor on the tensor. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf)
{
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == tensor) {
    result = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  } else {
    tensor->scf = scf;
  }
  return result;
}
/* Destroys a tensor handle: releases its copied layout and the handle itself.
 * The wrapped data buffer is caller-owned and is NOT freed. Destroying NULL
 * is a deliberate no-op (mirrors free(NULL)) and returns success. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor)
{
  if (0 != tensor) {
    if (0 != tensor->layout) {
      libxsmm_dnn_destroy_tensor_datalayout((libxsmm_dnn_tensor_datalayout*)tensor->layout);
    }
    free((libxsmm_dnn_tensor*)tensor); /* cast removes constness */
  }
  return LIBXSMM_DNN_SUCCESS;
}
/* Copies user data in a canonical format (NCHW for activations/bias-like
 * tensors, KCRS for filters) into the tensor's internal LIBXSMM blocked
 * layout. The actual copy loops live in per-category templates that are
 * expanded once per supported datatype. Only the LIBXSMM destination format
 * is supported; any other combination yields an error status. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* @TODO check for valid combination */
  if (0 != tensor) {
    switch (tensor->layout->tensor_type) {
      /* activation-like tensors: NCHW source layout */
      case LIBXSMM_DNN_REGULAR_INPUT:
      case LIBXSMM_DNN_GRADIENT_INPUT:
      case LIBXSMM_DNN_REGULAR_OUTPUT:
      case LIBXSMM_DNN_GRADIENT_OUTPUT:
      case LIBXSMM_DNN_INPUT:
      case LIBXSMM_DNN_OUTPUT:
      case LIBXSMM_DNN_ACTIVATION: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
                  /* low-precision variant of the template performs the down-convert */
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef unsigned char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      /* filter tensors: KCRS source layout */
      case LIBXSMM_DNN_REGULAR_FILTER:
      case LIBXSMM_DNN_GRADIENT_FILTER:
      case LIBXSMM_DNN_FILTER: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      /* per-channel vectors (bias, batch-norm statistics, scaling): NCHW source */
      case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
      case LIBXSMM_DNN_CHANNEL_BIAS:
      case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
      case LIBXSMM_DNN_CHANNEL_BETA:
      case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
      case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
      case LIBXSMM_DNN_CHANNEL_VARIANCE:
      case LIBXSMM_DNN_CHANNEL_SCALAR: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return status;
}
/* Sets every element of the tensor to zero for the tensor's datatype. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == tensor) {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  } else {
    const size_t count = libxsmm_dnn_get_tensor_elements(tensor->layout, &status);
    size_t i;
    /* element-wise stores (rather than memset) to potentially leverage NUMA in the future */
    switch (tensor->layout->datatype) {
      case LIBXSMM_DNN_DATATYPE_F32: {
        float *const buffer = (float*)tensor->data;
        for (i = 0; i < count; ++i) buffer[i] = 0.0f;
      } break;
      case LIBXSMM_DNN_DATATYPE_BF16: {
        libxsmm_bfloat16 *const buffer = (libxsmm_bfloat16*)tensor->data;
        for (i = 0; i < count; ++i) buffer[i] = 0;
      } break;
      case LIBXSMM_DNN_DATATYPE_I32: {
        int *const buffer = (int*)tensor->data;
        for (i = 0; i < count; ++i) buffer[i] = 0;
      } break;
      case LIBXSMM_DNN_DATATYPE_I16: {
        short *const buffer = (short*)tensor->data;
        for (i = 0; i < count; ++i) buffer[i] = 0;
      } break;
      case LIBXSMM_DNN_DATATYPE_I8: {
        char *const buffer = (char*)tensor->data;
        for (i = 0; i < count; ++i) buffer[i] = 0;
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      }
    }
  }
  return status;
}
/* Copies a LIBXSMM-internal (blocked) tensor out into user memory 'data' in a
 * common format: NCHW for activations and per-channel quantities, KCRS for
 * filters. The actual copy loops live in the included template files and are
 * parameterized by the 'element_type' typedef visible at the include site;
 * LIBXSMM_DNN_COPY_LOW_PRECISION selects the low-precision template variant.
 * Returns an error status for NULL tensors, unsupported formats or datatypes. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format)
{
libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
/* @TODO check for valid combination */
if (0 != tensor) {
switch (tensor->layout->tensor_type) {
/* activation-style tensors: copied out as NCHW */
case LIBXSMM_DNN_REGULAR_INPUT:
case LIBXSMM_DNN_GRADIENT_INPUT:
case LIBXSMM_DNN_REGULAR_OUTPUT:
case LIBXSMM_DNN_GRADIENT_OUTPUT:
case LIBXSMM_DNN_INPUT:
case LIBXSMM_DNN_OUTPUT:
case LIBXSMM_DNN_ACTIVATION: {
switch (out_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
/* source must be in LIBXSMM's internal format for these templates to apply */
if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
switch (tensor->layout->datatype) {
case LIBXSMM_DNN_DATATYPE_F32: {
typedef float element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_BF16: {
typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
} break;
case LIBXSMM_DNN_DATATYPE_I32: {
typedef int element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_I16: {
typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
} break;
case LIBXSMM_DNN_DATATYPE_I8: {
typedef unsigned char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
} break;
default: {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
}
} break;
default: {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
}
}
} break;
/* filter tensors: copied out as KCRS */
case LIBXSMM_DNN_REGULAR_FILTER:
case LIBXSMM_DNN_GRADIENT_FILTER:
case LIBXSMM_DNN_FILTER: {
switch (out_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
switch (tensor->layout->datatype) {
case LIBXSMM_DNN_DATATYPE_F32: {
typedef float element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_BF16: {
typedef libxsmm_bfloat16 element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_I32: {
typedef int element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_I16: {
typedef short element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_I8: {
typedef char element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
} break;
default: {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
}
} break;
default: {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
}
}
} break;
/* per-channel quantities (bias/beta/gamma/statistics): copied out as NCHW-style
 * channel vectors; note I32 is intentionally absent here (unlike activations) */
case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
case LIBXSMM_DNN_CHANNEL_BIAS:
case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
case LIBXSMM_DNN_CHANNEL_BETA:
case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
case LIBXSMM_DNN_CHANNEL_GAMMA:
case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
case LIBXSMM_DNN_CHANNEL_VARIANCE:
case LIBXSMM_DNN_CHANNEL_SCALAR: {
switch (out_format) {
case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
switch (tensor->layout->datatype) {
case LIBXSMM_DNN_DATATYPE_F32: {
typedef float element_type;
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
} break;
case LIBXSMM_DNN_DATATYPE_BF16: {
typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
} break;
case LIBXSMM_DNN_DATATYPE_I16: {
typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
} break;
case LIBXSMM_DNN_DATATYPE_I8: {
typedef char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
} break;
default: {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
}
} else {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
}
} break;
default: {
status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
}
}
} break;
default: {
status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
}
}
}
else {
status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
}
return status;
}
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_ext.h"
#include "libxsmm_gemm.h"
#include <libxsmm.h>
#if defined(LIBXSMM_BUILD)
#if defined(LIBXSMM_BUILD_EXT) && !defined(__STATIC)
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void LIBXSMM_FSYMBOL(dgemm_batch)(const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  /* Weak stand-in for the Fortran dgemm_batch symbol: when the original symbol
   * was never resolved, report the missing BLAS function; otherwise route the
   * call through LIBXSMM's wrap entry point. */
  if (libxsmm_original_dgemm_batch_function == LIBXSMM_FSYMBOL(__real_dgemm_batch)) {
    libxsmm_blas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array,
      beta_array, c_array, ldc_array, group_count, group_size);
  }
  else {
    LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array,
      beta_array, c_array, ldc_array, group_count, group_size);
  }
}
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void LIBXSMM_FSYMBOL(sgemm_batch)(const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  /* Weak stand-in for the Fortran sgemm_batch symbol: when the original symbol
   * was never resolved, report the missing BLAS function; otherwise route the
   * call through LIBXSMM's wrap entry point. */
  if (libxsmm_original_sgemm_batch_function == LIBXSMM_FSYMBOL(__real_sgemm_batch)) {
    libxsmm_blas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array,
      beta_array, c_array, ldc_array, group_count, group_size);
  }
  else {
    LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array,
      beta_array, c_array, ldc_array, group_count, group_size);
  }
}
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void LIBXSMM_FSYMBOL(dgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
  /* Weak fallback for the Fortran dgemm symbol: either forward to LIBXSMM's
   * wrap entry point, or emit the missing-BLAS diagnostic. */
  if (libxsmm_original_dgemm_function == LIBXSMM_FSYMBOL(__real_dgemm)) {
    libxsmm_blas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
  else {
    LIBXSMM_FSYMBOL(__wrap_dgemm)(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void LIBXSMM_FSYMBOL(sgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
  /* Weak fallback for the Fortran sgemm symbol: either forward to LIBXSMM's
   * wrap entry point, or emit the missing-BLAS diagnostic. */
  if (libxsmm_original_sgemm_function == LIBXSMM_FSYMBOL(__real_sgemm)) {
    libxsmm_blas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
  else {
    LIBXSMM_FSYMBOL(__wrap_sgemm)(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void LIBXSMM_FSYMBOL(dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx,
  const double* beta, double* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
  /* Weak fallback for the Fortran dgemv symbol: either forward to LIBXSMM's
   * wrap entry point, or emit the missing-BLAS diagnostic. */
  if (libxsmm_original_dgemv_function == LIBXSMM_FSYMBOL(__real_dgemv)) {
    libxsmm_blas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
  else {
    LIBXSMM_FSYMBOL(__wrap_dgemv)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
}
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void LIBXSMM_FSYMBOL(sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx,
  const float* beta, float* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
  /* Weak fallback for the Fortran sgemv symbol: either forward to LIBXSMM's
   * wrap entry point, or emit the missing-BLAS diagnostic. */
  if (libxsmm_original_sgemv_function == LIBXSMM_FSYMBOL(__real_sgemv)) {
    libxsmm_blas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
  else {
    LIBXSMM_FSYMBOL(__wrap_sgemv)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
}
/* Plain C (unmangled) dgemm_batch entry: forwards to the Fortran-mangled weak
 * symbol so both linkage styles resolve to the same dispatch logic. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void dgemm_batch(const char transa_array[], const char transb_array[],
const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
const double* b_array[], const libxsmm_blasint ldb_array[],
const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
/* Plain C (unmangled) sgemm_batch entry: forwards to the Fortran-mangled weak
 * symbol so both linkage styles resolve to the same dispatch logic. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK
void sgemm_batch(const char transa_array[], const char transb_array[],
const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
const float* b_array[], const libxsmm_blasint ldb_array[],
const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
#elif (0 != LIBXSMM_NO_BLAS) /* no-BLAS library */
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state0[16]);
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state1[16]);
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state2[16]);
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state3[16]);
/* Variadic argument sink used by the no-BLAS build: accepts any arguments and
 * intentionally performs no work, so a "missing" BLAS call degrades to a no-op. */
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE void internal_noblas_sink(LIBXSMM_VARIADIC);
LIBXSMM_API_INTERN void internal_noblas_sink(LIBXSMM_VARIADIC)
{
/* does nothing else but sinking given arguments */
}
/* Reports a call into an unavailable BLAS 'symbol' via LIBXSMM_BLAS_ERROR (the
 * static counter presumably limits repeated reporting - confirm against the
 * macro's definition), then returns the no-op sink so the caller can still
 * "invoke" something with the original arguments. */
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE libxsmm_sink_function internal_noblas_error(const char* /*symbol*/);
LIBXSMM_API_INTERN libxsmm_sink_function internal_noblas_error(const char* symbol)
{
static int internal_noblas_nerror = 0;
LIBXSMM_BLAS_ERROR(symbol, &internal_noblas_nerror);
return internal_noblas_sink;
}
/* No-BLAS build: dgemm_batch only reports the missing symbol and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/
void LIBXSMM_FSYMBOL(dgemm_batch)(const char transa_array[], const char transb_array[],
const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
const double* b_array[], const libxsmm_blasint ldb_array[],
const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
internal_noblas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
/* No-BLAS build: sgemm_batch only reports the missing symbol and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/
void LIBXSMM_FSYMBOL(sgemm_batch)(const char transa_array[], const char transb_array[],
const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
const float* b_array[], const libxsmm_blasint ldb_array[],
const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
internal_noblas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
/* No-BLAS build: dgemm only reports the missing symbol and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/
void LIBXSMM_FSYMBOL(dgemm)(const char* transa, const char* transb,
const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
const double* alpha, const double* a, const libxsmm_blasint* lda,
const double* b, const libxsmm_blasint* ldb,
const double* beta, double* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
internal_noblas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* No-BLAS build: sgemm only reports the missing symbol and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/
void LIBXSMM_FSYMBOL(sgemm)(const char* transa, const char* transb,
const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
const float* alpha, const float* a, const libxsmm_blasint* lda,
const float* b, const libxsmm_blasint* ldb,
const float* beta, float* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
internal_noblas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* No-BLAS build: dgemv only reports the missing symbol and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/
void LIBXSMM_FSYMBOL(dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx,
const double* beta, double* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
internal_noblas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
/* No-BLAS build: sgemv only reports the missing symbol and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/
void LIBXSMM_FSYMBOL(sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx,
const float* beta, float* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
internal_noblas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
/* No-BLAS build, plain C entry: forwards to the Fortran-mangled no-BLAS stub above. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE
void dgemm_batch(const char transa_array[], const char transb_array[],
const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
const double* b_array[], const libxsmm_blasint ldb_array[],
const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
/* No-BLAS build, plain C entry: forwards to the Fortran-mangled no-BLAS stub above. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE
void sgemm_batch(const char transa_array[], const char transb_array[],
const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
const float* b_array[], const libxsmm_blasint ldb_array[],
const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
#endif
#endif /*defined(LIBXSMM_BUILD)*/
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_EXT_H
#define LIBXSMM_EXT_H
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#if defined(_OPENMP)
/* omp.h is not -Wpedantic clean with some compilers: temporarily suppress the
 * warning around the include and restore the diagnostic state afterwards. */
# if !defined(__INTEL_COMPILER)
# if defined(__clang__)
# pragma clang diagnostic push
# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)
# pragma GCC diagnostic push
# endif
# if defined(__clang__)
# pragma clang diagnostic ignored "-Wpedantic"
# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)
# pragma GCC diagnostic ignored "-Wpedantic"
# endif
# endif
# include <omp.h>
/* fix: the pop-guard must mirror the push-guard above; it previously also
 * required LIBXSMM_TRACE_CALLERID_GCCBUILTIN, which left the diagnostic stack
 * unbalanced (and "-Wpedantic" silenced for the rest of every including
 * translation unit) whenever that macro was not defined. */
# if !defined(__INTEL_COMPILER)
# if defined(__clang__)
# pragma clang diagnostic pop
# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)
# pragma GCC diagnostic pop
# endif
# endif
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#endif /*LIBXSMM_EXT_H*/
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm.h>
#include "libxsmm_gemm.h"
#include "libxsmm_ext.h"
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# include "libxsmm_trace.h"
#endif
#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) && 0
# define LIBXSMM_EXT_GEMM_PARGROUPS_INFO
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# if !defined(LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH)
# define LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO)
# endif
# if !defined(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH)
# define LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH 8/*POT*/
# endif
LIBXSMM_APIVAR_DEFINE(libxsmm_gemm_descriptor internal_ext_gemm_batchdesc[LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH]);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_ext_gemm_batchdepth);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_ext_gemm_batchsize);
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
/* qsort comparator ordering libxsmm_mmbatch_item records by descending call
 * count (most frequent multiplication first); ties compare equal. */
LIBXSMM_API_INLINE int internal_mmbatch_sortrev(const void* stat_a, const void* stat_b)
{
  const libxsmm_mmbatch_item *const item_a = (const libxsmm_mmbatch_item*)stat_a;
  const libxsmm_mmbatch_item *const item_b = (const libxsmm_mmbatch_item*)stat_b;
  int order = 0;
  LIBXSMM_ASSERT(NULL != stat_a && NULL != stat_b);
  if (item_a->stat.count < item_b->stat.count) {
    order = 1; /* smaller count sorts later */
  }
  else if (item_b->stat.count < item_a->stat.count) {
    order = -1; /* larger count sorts earlier */
  }
  return order;
}
#endif /*defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)*/
/* Flushes a recorded batch of multiplications: either executes the pending
 * GEMMs (sequentially, OpenMP-parallel, or via OpenMP tasks, depending on
 * flags and nesting) or, in statistics mode, prints a sorted call-frequency
 * summary. Returns EXIT_SUCCESS/EXIT_FAILURE; the batch array is cleared in
 * both modes. Compiled to a no-op unless LIBXSMM_WRAP and LIBXSMM_BUILD_EXT. */
LIBXSMM_API_INLINE int internal_mmbatch_flush(const libxsmm_gemm_descriptor* batchdesc,
libxsmm_blasint batchsize, libxsmm_mmbatch_item* batcharray)
{
int result = EXIT_SUCCESS;
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
if (0 != batchsize) { /* recorded/lazy multiplications */
const libxsmm_blasint itemsize = sizeof(libxsmm_mmbatch_item);
LIBXSMM_ASSERT(NULL != batchdesc && 0 < batchsize);
if (0 == (LIBXSMM_MMBATCH_FLAG_STATISTIC & batchdesc->flags)) { /* process batch */
const libxsmm_xmmfunction kernel = libxsmm_xmmdispatch(batchdesc);
if (NULL != kernel.xmm) {
const unsigned char itypesize = libxsmm_typesize((libxsmm_datatype)LIBXSMM_GETENUM_INP(batchdesc->datatype));
const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)LIBXSMM_GETENUM_OUT(batchdesc->datatype));
#if defined(_OPENMP)
if (0 == (LIBXSMM_MMBATCH_FLAG_SEQUENTIAL & batchdesc->flags)) { /* parallelized */
const int nchunks = (int)LIBXSMM_UPDIV(batchsize, libxsmm_gemm_taskgrain);
# if defined(LIBXSMM_EXT_TASKS)
if (0 == omp_get_active_level()) {
const int max_nthreads = omp_get_max_threads();
const int nthreads = LIBXSMM_MIN(max_nthreads, nchunks);
if (0 == libxsmm_gemm_tasks)
# else
if (0 == omp_in_parallel()) {
const int max_nthreads = omp_get_max_threads();
const int nthreads = LIBXSMM_MIN(max_nthreads, nchunks);
# endif
{ /* classic internal parallelization */
# pragma omp parallel num_threads(nthreads)
/*check*/libxsmm_mmbatch_kernel(
kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize,
&batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize,
omp_get_thread_num(), nthreads, itypesize, otypesize, batchdesc->flags);
}
# if defined(LIBXSMM_EXT_TASKS)
else { /* internal parallelization with tasks */
# pragma omp parallel num_threads(nthreads)
{ /* first thread discovering work will launch all tasks */
# pragma omp single nowait /* anyone is good */
{ int tid; for (tid = 0; tid < nchunks/*ntasks*/; ++tid) {
# pragma omp task untied
/*check*/libxsmm_mmbatch_kernel(
kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize,
&batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize,
tid, nchunks/*ntasks*/, itypesize, otypesize, batchdesc->flags);
}
}
} /* implicit synchronization (barrier) */
}
# endif
}
else { /* assume external parallelization */
int tid; for (tid = 0; tid < nchunks/*ntasks*/; ++tid) {
# if defined(LIBXSMM_EXT_TASKS)
# pragma omp task untied
#endif
/*check*/libxsmm_mmbatch_kernel(
kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize,
&batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize,
tid, nchunks/*ntasks*/, itypesize, otypesize, batchdesc->flags);
}
# if defined(LIBXSMM_EXT_TASKS)
if (0 == libxsmm_nosync) { /* allow to omit synchronization */
# pragma omp taskwait
}
# endif
}
}
else
#endif
{ /* sequential */
result = libxsmm_mmbatch_kernel(
kernel, 0/*index_base*/, 0/*index_stride*/, &itemsize, &itemsize, &itemsize,
&batcharray->value.a, &batcharray->value.b, &batcharray->value.c, batchsize,
0/*tid*/, 1/*nthreads*/, itypesize, otypesize, batchdesc->flags);
}
}
else { /* no fallback */
/* several reasons to arrive here: try-lock, unsuitable SMM, etc. */
result = EXIT_FAILURE;
}
memset(batcharray, 0, (size_t)batchsize * (size_t)itemsize); /* clear */
}
else { /* print statistic */
/* sort records by descending call count, then print entries above a
 * relative threshold (up to 'limit'); each record is cleared as it goes */
const libxsmm_blasint limit = (LIBXSMM_GEMM_MMBATCH_VERBOSITY < libxsmm_verbosity ? batchsize/*unlimited*/ : 7/*limited*/);
unsigned int threshold, batchcount;
libxsmm_blasint count = 0, i;
LIBXSMM_ASSERT(NULL != batcharray);
qsort(batcharray, (size_t)batchsize, (size_t)itemsize, internal_mmbatch_sortrev);
batchcount = batcharray[0].stat.count;
threshold = ((LIBXSMM_GEMM_MMBATCH_VERBOSITY < libxsmm_verbosity || 3 >= batchsize) ? 0 : (batchcount / 2));
for (i = 1; i < batchsize; ++i) batchcount += batcharray[i].stat.count;
LIBXSMM_STDIO_ACQUIRE();
for (i = 0; i < batchsize; ++i) {
const libxsmm_gemm_descriptor descriptor = batcharray[i].stat.desc;
const libxsmm_blasint lda = descriptor.lda, ldb = descriptor.ldb, ldc = descriptor.ldc;
const libxsmm_blasint m = descriptor.m, n = descriptor.n, k = descriptor.k;
const char *const symbol = batcharray[i].stat.symbol;
const unsigned int ci = batcharray[i].stat.count;
LIBXSMM_MEMZERO127(batcharray + i); /* clear */
if (threshold < ci && count < limit /* limit printed statistic */
&& 0 < m && 0 < n && 0 < k)
{
/* percentage of this GEMM shape relative to all recorded calls */
const unsigned int ciperc = (unsigned int)(100.0 * ci / batchcount + 0.5);
if (0 != ciperc) {
LIBXSMM_ASSERT(0 != ci);
if (0 == count) {
fprintf(stderr, "\nLIBXSMM STATISTIC: %u multiplication%c\n", batchcount, 1 < batchcount ? 's' : ' ');
}
LIBXSMM_GEMM_PRINT2(stderr,
LIBXSMM_GETENUM_INP(descriptor.datatype), LIBXSMM_GETENUM_OUT(descriptor.datatype), descriptor.flags, m, n, k,
/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & descriptor.flags) ? 0 : */1, NULL/*a*/, lda, NULL/*b*/, ldb,
0 != (LIBXSMM_GEMM_FLAG_BETA_0 & descriptor.flags) ? 0 : 1, NULL/*c*/, ldc);
if (NULL != symbol && 0 != *symbol) {
fprintf(stderr, ": %u%% [%s]\n", ciperc, symbol);
}
else {
fprintf(stderr, ": %u%%\n", ciperc);
}
++count;
}
else break;
}
}
LIBXSMM_STDIO_RELEASE();
}
}
#else
LIBXSMM_UNUSED(batchdesc); LIBXSMM_UNUSED(batchsize); LIBXSMM_UNUSED(batcharray);
#endif
return result;
}
#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT)
#if defined(LIBXSMM_BLAS_WRAP_DYNAMIC)
/* Resolves (and caches into the _function variable) the original, non-wrapped
 * dgemm_batch symbol via LIBXSMM_BLAS_WRAPPER; the NULL-assert stays disabled,
 * presumably because batch symbols may be absent from older BLAS libraries. */
LIBXSMM_API libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void)
{
# if (0 != LIBXSMM_BLAS)
LIBXSMM_BLAS_WRAPPER(1, double, gemm_batch, libxsmm_original_dgemm_batch_function, libxsmm_original_dgemm_batch/*self*/);
/*LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_batch_function);*/
# else
LIBXSMM_BLAS_WRAPPER(0, double, gemm_batch, libxsmm_original_dgemm_batch_function, libxsmm_original_dgemm_batch/*self*/);
# endif
return libxsmm_original_dgemm_batch_function;
}
/* Resolves (and caches) the original sgemm_batch symbol; see the dgemm_batch
 * variant for why the NULL-assert is disabled. */
LIBXSMM_API libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void)
{
# if (0 != LIBXSMM_BLAS)
LIBXSMM_BLAS_WRAPPER(1, float, gemm_batch, libxsmm_original_sgemm_batch_function, libxsmm_original_sgemm_batch/*self*/);
/*LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_batch_function);*/
# else
LIBXSMM_BLAS_WRAPPER(0, float, gemm_batch, libxsmm_original_sgemm_batch_function, libxsmm_original_sgemm_batch/*self*/);
# endif
return libxsmm_original_sgemm_batch_function;
}
/* Resolves (and caches) the original, non-wrapped dgemm symbol. */
LIBXSMM_API libxsmm_dgemm_function libxsmm_original_dgemm(void)
{
# if (0 != LIBXSMM_BLAS)
LIBXSMM_BLAS_WRAPPER(1, double, gemm, libxsmm_original_dgemm_function, libxsmm_original_dgemm/*self*/);
LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_function);
# else
LIBXSMM_BLAS_WRAPPER(0, double, gemm, libxsmm_original_dgemm_function, libxsmm_original_dgemm/*self*/);
# endif
return libxsmm_original_dgemm_function;
}
/* Resolves (and caches) the original, non-wrapped sgemm symbol. */
LIBXSMM_API libxsmm_sgemm_function libxsmm_original_sgemm(void)
{
# if (0 != LIBXSMM_BLAS)
LIBXSMM_BLAS_WRAPPER(1, float, gemm, libxsmm_original_sgemm_function, libxsmm_original_sgemm/*self*/);
LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_function);
# else
LIBXSMM_BLAS_WRAPPER(0, float, gemm, libxsmm_original_sgemm_function, libxsmm_original_sgemm/*self*/);
# endif
return libxsmm_original_sgemm_function;
}
/* Resolves (and caches) the original, non-wrapped dgemv symbol. */
LIBXSMM_API libxsmm_dgemv_function libxsmm_original_dgemv(void)
{
# if (0 != LIBXSMM_BLAS)
LIBXSMM_BLAS_WRAPPER(1, double, gemv, libxsmm_original_dgemv_function, libxsmm_original_dgemv/*self*/);
LIBXSMM_ASSERT(NULL != libxsmm_original_dgemv_function);
# else
LIBXSMM_BLAS_WRAPPER(0, double, gemv, libxsmm_original_dgemv_function, libxsmm_original_dgemv/*self*/);
# endif
return libxsmm_original_dgemv_function;
}
/* Resolves (and caches) the original, non-wrapped sgemv symbol. */
LIBXSMM_API libxsmm_sgemv_function libxsmm_original_sgemv(void)
{
# if (0 != LIBXSMM_BLAS)
LIBXSMM_BLAS_WRAPPER(1, float, gemv, libxsmm_original_sgemv_function, libxsmm_original_sgemv/*self*/);
LIBXSMM_ASSERT(NULL != libxsmm_original_sgemv_function);
# else
LIBXSMM_BLAS_WRAPPER(0, float, gemv, libxsmm_original_sgemv_function, libxsmm_original_sgemv/*self*/);
# endif
return libxsmm_original_sgemv_function;
}
#endif /*defined(LIBXSMM_BLAS_WRAP_DYNAMIC)*/
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(
const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
LIBXSMM_ASSERT(NULL != lda_array && NULL != ldb_array && NULL != ldc_array && NULL != m_array && NULL != n_array && NULL != k_array);
LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != alpha_array && NULL != beta_array);
LIBXSMM_ASSERT(NULL != group_count && NULL != group_size);
LIBXSMM_INIT
if (0 != libxsmm_gemm_wrap) {
if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
libxsmm_dgemm_batch(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
else { /* parallelized */
libxsmm_dgemm_batch_omp(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
}
else {
LIBXSMM_GEMM_BATCH_SYMBOL(double)(transa_array, transb_array, m_array, n_array, k_array,
alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
group_count, group_size);
}
}
/* Interception point for sgemm_batch: depending on libxsmm_gemm_wrap, either
 * fall through to the original BLAS symbol (0), run LIBXSMM's sequential batch
 * (odd values), or LIBXSMM's OpenMP-parallel batch (even non-zero values). */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(
  const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_ASSERT(NULL != lda_array && NULL != ldb_array && NULL != ldc_array && NULL != m_array && NULL != n_array && NULL != k_array);
  LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != alpha_array && NULL != beta_array);
  LIBXSMM_ASSERT(NULL != group_count && NULL != group_size);
  LIBXSMM_INIT
  if (0 == libxsmm_gemm_wrap) { /* interception disabled: original BLAS */
    LIBXSMM_GEMM_BATCH_SYMBOL(float)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
  else if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
    libxsmm_sgemm_batch(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
  else { /* parallelized */
    libxsmm_sgemm_batch_omp(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
}
/* Intercepted Fortran-symbol DGEMM. When batch-recording is active and this
 * call matches the recorded descriptor (libxsmm_mmbatch_desc), the operand
 * pointers are recorded for deferred batched execution; otherwise the GEMM is
 * executed immediately via LIBXSMM (sequential or OpenMP-parallel per
 * libxsmm_gemm_wrap), optionally verified against reference BLAS (_DEBUG). */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  LIBXSMM_ASSERT(NULL != lda && NULL != ldb && NULL != ldc && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != transa && NULL != transb && NULL != alpha && NULL != beta);
  {
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
    unsigned int i = 0; /* no flush */
    int flags = -1;
# if !defined(NDEBUG)
    static int error_once = 0;
    int result = EXIT_SUCCESS;
# endif
    LIBXSMM_INIT
    /* execute (rather than record) when recording is off, or when this call
     * does not match the recorded descriptor: precision, LDs, shape, flags,
     * and alpha/beta (alpha must be 1; beta must match BETA_0 flag) */
    if (0 != libxsmm_gemm_wrap && (NULL == libxsmm_mmbatch_array
      || LIBXSMM_GEMM_PRECISION_F64 != libxsmm_mmbatch_desc.datatype
      || ((unsigned int)*lda) != libxsmm_mmbatch_desc.lda
      || ((unsigned int)*ldb) != libxsmm_mmbatch_desc.ldb
      || ((unsigned int)*ldc) != libxsmm_mmbatch_desc.ldc
      || ((unsigned int)*m) != libxsmm_mmbatch_desc.m
      || ((unsigned int)*n) != libxsmm_mmbatch_desc.n
      || ((unsigned int)*k) != libxsmm_mmbatch_desc.k
      || (flags = LIBXSMM_GEMM_FLAGS(*transa, *transb)) != (int)(LIBXSMM_GEMM_FLAG_TRANS_AB & libxsmm_mmbatch_desc.flags)
      || LIBXSMM_NEQ(/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, *alpha)
      || LIBXSMM_NEQ(0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, *beta)))
#endif
    {
#if defined(_DEBUG)
      /* LIBXSMM_GEMM_CHECK (env): nonzero enables result verification below */
      const char *const env_check = getenv("LIBXSMM_GEMM_CHECK");
      const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
      void* d = NULL;
      if (LIBXSMM_NEQ(0, check)) {
        const size_t size = (size_t)(*ldc) * (size_t)(*n) * sizeof(double);
        d = libxsmm_scratch_malloc(size, 0/*auto*/, LIBXSMM_MALLOC_INTERNAL_CALLER);
        if (NULL != d && LIBXSMM_NEQ(0, *beta)) memcpy(d, c, size); /* copy destination */
      }
#endif
      if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
        libxsmm_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
      }
      else { /* parallelized */
        libxsmm_dgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
      }
#if defined(_DEBUG)
      /* compare against reference BLAS and warn when the relative error
       * exceeds the LIBXSMM_GEMM_CHECK threshold (in percent) */
      if (NULL != d) {
        libxsmm_matdiff_info diff;
        libxsmm_blas_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, d, ldc);
        if (EXIT_SUCCESS == libxsmm_matdiff(&diff, LIBXSMM_DATATYPE_F64, *m, *n, d, c, ldc, ldc)
          && check < 100.0 * diff.normf_rel)
        {
          LIBXSMM_STDIO_ACQUIRE();
          fprintf(stderr, "LIBXSMM: ");
          libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION_F64, transa, transb,
            m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          fprintf(stderr, " => %f%% ERROR\n", 100.0 * diff.normf_rel);
          LIBXSMM_STDIO_RELEASE();
        }
        libxsmm_free(d);
      }
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
      /* statistics mode: count distinct GEMM shapes (descriptors) instead of
       * recording operand pointers; the array acts as a ring of entries */
      if (0 != (LIBXSMM_MMBATCH_FLAG_STATISTIC & libxsmm_mmbatch_desc.flags)) {
        libxsmm_descriptor_blob blob;
        const libxsmm_gemm_descriptor *const descriptor = libxsmm_dgemm_descriptor_init(&blob,
          *m, *n, *k, *lda, *ldb, *ldc, *alpha, *beta, LIBXSMM_GEMM_FLAGS(*transa, *transb),
          LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH);
        LIBXSMM_ASSERT(0 != libxsmm_mmbatch_size);
        if (NULL != descriptor) {
          const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
          const unsigned int batchsize = LIBXSMM_ATOMIC_LOAD(&internal_ext_gemm_batchsize, LIBXSMM_ATOMIC_RELAXED);
          const unsigned int max_size = (0 != batchsize ? (((batchsize - 1) % max_batchsize) + 1) : 0);
          libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array;
          libxsmm_mmbatch_item* batcharray_cur = batcharray;
          unsigned int size = max_size;
          if (libxsmm_mmbatch_size < max_size) {
            size = max_size - libxsmm_mmbatch_size;
            batcharray_cur += libxsmm_mmbatch_size;
          }
          /* linear search for an entry with an identical descriptor */
          i = libxsmm_diff_n(descriptor, batcharray_cur, sizeof(libxsmm_gemm_descriptor),
            sizeof(libxsmm_mmbatch_item)/*stride*/, 0/*hint*/, size);
          if (i < size) { /* update existing entry */
            LIBXSMM_ATOMIC_ADD_FETCH(&batcharray_cur[i].stat.count, 1, LIBXSMM_ATOMIC_RELAXED);
          }
          else { /* new entry needed */
            const int all = -1, shift = 0;
            void* extra = 0;
            i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1;
            batcharray[i-1].stat.desc = *descriptor;
            batcharray[i-1].stat.count = 1;
            batcharray[i-1].stat.symbol = libxsmm_trace_info(NULL/*depth*/, NULL/*tid*/, &all, LIBXSMM_FUNCNAME, &shift, &all);
            /* register a flush callback in the allocation's extra storage */
            if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(libxsmm_mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra)) {
              *(libxsmm_mmbatch_flush_function*)extra = libxsmm_mmbatch_end;
            }
# if !defined(NDEBUG)
            else {
              result = EXIT_FAILURE;
            }
# endif
          }
        }
      }
#endif
    }
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
    else { /* record the operand pointers of this call for deferred execution */
      libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array;
      const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
      i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1;
      batcharray[i-1].value.a = a;
      batcharray[i-1].value.b = b;
      batcharray[i-1].value.c = c;
      LIBXSMM_ASSERT(0 <= flags);
    }
    if (libxsmm_mmbatch_size == (i - 1)) { /* condition ensure to flush once (first discovery) */
# if !defined(NDEBUG)
      result =
# endif
      internal_mmbatch_flush(&libxsmm_mmbatch_desc, libxsmm_mmbatch_size, (libxsmm_mmbatch_item*)libxsmm_mmbatch_array);
    }
# if !defined(NDEBUG) /* library code is expected to be mute */
    if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity &&
      1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: DGEMM batch recording failed!\n");
    }
# endif
#endif
  }
}
/* Intercepted Fortran-symbol SGEMM (single-precision twin of __wrap_dgemm).
 * Records the call for deferred batched execution when it matches the active
 * batch descriptor; otherwise executes via LIBXSMM (sequential/OpenMP),
 * optionally verified against reference BLAS (_DEBUG). */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  LIBXSMM_ASSERT(NULL != lda && NULL != ldb && NULL != ldc && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != transa && NULL != transb && NULL != alpha && NULL != beta);
  {
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
    unsigned int i = 0; /* no flush */
    int flags = -1;
# if !defined(NDEBUG)
    static int error_once = 0;
    int result = EXIT_SUCCESS;
# endif
    LIBXSMM_INIT
    /* execute (rather than record) when recording is off, or when this call
     * does not match the recorded descriptor (precision, LDs, shape, flags,
     * alpha must be 1, beta must match BETA_0 flag) */
    if (0 != libxsmm_gemm_wrap && (NULL == libxsmm_mmbatch_array
      || LIBXSMM_GEMM_PRECISION_F32 != libxsmm_mmbatch_desc.datatype
      || ((unsigned int)*lda) != libxsmm_mmbatch_desc.lda
      || ((unsigned int)*ldb) != libxsmm_mmbatch_desc.ldb
      || ((unsigned int)*ldc) != libxsmm_mmbatch_desc.ldc
      || ((unsigned int)*m) != libxsmm_mmbatch_desc.m
      || ((unsigned int)*n) != libxsmm_mmbatch_desc.n
      || ((unsigned int)*k) != libxsmm_mmbatch_desc.k
      || (flags = LIBXSMM_GEMM_FLAGS(*transa, *transb)) != (int)(LIBXSMM_GEMM_FLAG_TRANS_AB & libxsmm_mmbatch_desc.flags)
      || LIBXSMM_NEQ(/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, *alpha)
      || LIBXSMM_NEQ(0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, *beta)))
#endif
    {
#if defined(_DEBUG)
      /* LIBXSMM_GEMM_CHECK (env): nonzero enables result verification below */
      const char *const env_check = getenv("LIBXSMM_GEMM_CHECK");
      const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
      void* d = NULL;
      if (LIBXSMM_NEQ(0, check)) {
        const size_t size = (size_t)(*ldc) * (size_t)(*n) * sizeof(float);
        d = libxsmm_scratch_malloc(size, 0/*auto*/, LIBXSMM_MALLOC_INTERNAL_CALLER);
        if (NULL != d && LIBXSMM_NEQ(0, *beta)) memcpy(d, c, size); /* copy destination */
      }
#endif
      if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
        libxsmm_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
      }
      else { /* parallelized */
        libxsmm_sgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
      }
#if defined(_DEBUG)
      /* compare against reference BLAS and warn when the relative error
       * exceeds the LIBXSMM_GEMM_CHECK threshold (in percent) */
      if (NULL != d) {
        libxsmm_matdiff_info diff;
        libxsmm_blas_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, d, ldc);
        if (EXIT_SUCCESS == libxsmm_matdiff(&diff, LIBXSMM_DATATYPE_F32, *m, *n, d, c, ldc, ldc)
          && check < 100.0 * diff.normf_rel)
        {
          LIBXSMM_STDIO_ACQUIRE();
          fprintf(stderr, "LIBXSMM: ");
          libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION_F32, transa, transb,
            m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          fprintf(stderr, " => %f%% ERROR\n", 100.0 * diff.normf_rel);
          LIBXSMM_STDIO_RELEASE();
        }
        libxsmm_free(d);
      }
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
      /* statistics mode: count distinct GEMM shapes (descriptors) instead of
       * recording operand pointers; the array acts as a ring of entries */
      if (0 != (LIBXSMM_MMBATCH_FLAG_STATISTIC & libxsmm_mmbatch_desc.flags)) {
        libxsmm_descriptor_blob blob;
        const libxsmm_gemm_descriptor *const descriptor = libxsmm_sgemm_descriptor_init(&blob,
          *m, *n, *k, *lda, *ldb, *ldc, *alpha, *beta, LIBXSMM_GEMM_FLAGS(*transa, *transb),
          LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH);
        LIBXSMM_ASSERT(0 != libxsmm_mmbatch_size);
        if (NULL != descriptor) {
          const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
          const unsigned int batchsize = LIBXSMM_ATOMIC_LOAD(&internal_ext_gemm_batchsize, LIBXSMM_ATOMIC_RELAXED);
          const unsigned int max_size = (0 != batchsize ? (((batchsize - 1) % max_batchsize) + 1) : 0);
          libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array;
          libxsmm_mmbatch_item* batcharray_cur = batcharray;
          unsigned int size = max_size;
          if (libxsmm_mmbatch_size < max_size) {
            size = max_size - libxsmm_mmbatch_size;
            batcharray_cur += libxsmm_mmbatch_size;
          }
          /* linear search for an entry with an identical descriptor */
          i = libxsmm_diff_n(descriptor, batcharray_cur, sizeof(libxsmm_gemm_descriptor),
            sizeof(libxsmm_mmbatch_item)/*stride*/, 0/*hint*/, size);
          if (i < size) { /* update existing entry */
            LIBXSMM_ATOMIC_ADD_FETCH(&batcharray_cur[i].stat.count, 1, LIBXSMM_ATOMIC_RELAXED);
          }
          else { /* new entry needed */
            const int all = -1, shift = 0;
            void* extra = 0;
            i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1;
            batcharray[i-1].stat.desc = *descriptor;
            batcharray[i-1].stat.count = 1;
            batcharray[i-1].stat.symbol = libxsmm_trace_info(NULL/*depth*/, NULL/*tid*/, &all, LIBXSMM_FUNCNAME, &shift, &all);
            /* register a flush callback in the allocation's extra storage */
            if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(libxsmm_mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra)) {
              *(libxsmm_mmbatch_flush_function*)extra = libxsmm_mmbatch_end;
            }
# if !defined(NDEBUG)
            else {
              result = EXIT_FAILURE;
            }
# endif
          }
        }
      }
#endif
    }
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
    else { /* record the operand pointers of this call for deferred execution */
      libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array;
      const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
      i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1;
      batcharray[i-1].value.a = a;
      batcharray[i-1].value.b = b;
      batcharray[i-1].value.c = c;
      LIBXSMM_ASSERT(0 <= flags);
    }
    if (libxsmm_mmbatch_size == (i - 1)) { /* condition ensure to flush once (first discovery) */
# if !defined(NDEBUG)
      result =
# endif
      internal_mmbatch_flush(&libxsmm_mmbatch_desc, libxsmm_mmbatch_size, (libxsmm_mmbatch_item*)libxsmm_mmbatch_array);
    }
# if !defined(NDEBUG) /* library code is expected to be mute */
    if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity &&
      1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: SGEMM batch recording failed!\n");
    }
# endif
#endif
  }
}
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx,
const double* beta, double* y, const libxsmm_blasint* incy)
{
LIBXSMM_ASSERT(NULL != trans && NULL != m && NULL != n && NULL != lda && NULL != incx && NULL != incy && NULL != alpha && NULL != beta);
LIBXSMM_INIT
if ((2 < libxsmm_gemm_wrap || 2 > libxsmm_gemm_wrap) && 1 == *incx && 1 == *incy && LIBXSMM_SMM(*m, 1, *n, 2/*RFO*/, sizeof(double))) {
if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
const int flags = LIBXSMM_GEMM_FLAGS(*trans, 'N');
const libxsmm_dmmfunction xgemv = libxsmm_dmmdispatch(*m, 1, *n, lda, n/*ldb*/, m/*ldc*/, alpha, beta, &flags, NULL);
if (NULL != xgemv) {
LIBXSMM_MMCALL_LDX(xgemv, a, x, y, *m, 1, *n, *lda, *n/*ldb*/, *m/*ldc*/);
}
else {
LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
}
else { /* TODO: parallelized */
LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
}
else {
LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
}
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx,
const float* beta, float* y, const libxsmm_blasint* incy)
{
LIBXSMM_ASSERT(NULL != trans && NULL != m && NULL != n && NULL != lda && NULL != incx && NULL != incy && NULL != alpha && NULL != beta);
LIBXSMM_INIT
if ((2 < libxsmm_gemm_wrap || 2 > libxsmm_gemm_wrap) && 1 == *incx && 1 == *incy && LIBXSMM_SMM(*m, 1, *n, 2/*RFO*/, sizeof(float))) {
if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
const int flags = LIBXSMM_GEMM_FLAGS(*trans, 'N');
const libxsmm_smmfunction xgemv = libxsmm_smmdispatch(*m, 1, *n, lda, n/*ldb*/, m/*ldc*/, alpha, beta, &flags, NULL);
if (NULL != xgemv) {
LIBXSMM_MMCALL_LDX(xgemv, a, x, y, *m, 1, *n, *lda, *n/*ldb*/, *m/*ldc*/);
}
else {
LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
}
else { /* TODO: parallelized */
LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
}
else {
LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
}
/* C-callable (non-Fortran-decorated) alias of the wrapped DGEMM_BATCH:
 * simply forwards to the Fortran-symbol wrapper which implements the
 * actual interception logic. */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void __wrap_dgemm_batch(
  const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
/* C-callable (non-Fortran-decorated) alias of the wrapped SGEMM_BATCH:
 * simply forwards to the Fortran-symbol wrapper which implements the
 * actual interception logic. */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void __wrap_sgemm_batch(
  const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
#endif /*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT)*/
/* Parallelized, precision-generic GEMM: tiles the problem via a GEMM handle
 * and executes the tiles with OpenMP threads or tasks. Falls back to the
 * reference BLAS path when no handle can be created or when allocating the
 * scratch buffer fails. */
LIBXSMM_APIEXT void libxsmm_xgemm_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc)
{
  libxsmm_gemm_blob blob;
#if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
  const int outerpar = omp_get_active_level(), nthreads = (0 == outerpar ? omp_get_max_threads() : omp_get_num_threads());
#elif defined(_OPENMP)
  const int outerpar = omp_in_parallel(), nthreads = (0 == outerpar ? omp_get_max_threads() : 1);
#else
  const int nthreads = 1;
#endif
  const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob, iprec, oprec, transa, transb,
    m, n, k, lda, ldb, ldc, alpha, beta, LIBXSMM_GEMM_HANDLE_FLAG_AUTO, nthreads);
  const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle);
  void* scratch = NULL;
  if (NULL != handle && (0 == scratch_size ||
      NULL != (scratch = libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER))))
  {
#if defined(_OPENMP)
    if (0 == outerpar) { /* enable internal parallelization */
# if defined(LIBXSMM_EXT_TASKS)
      if (0 == libxsmm_gemm_tasks)
# endif
      {
# pragma omp parallel num_threads(nthreads)
        libxsmm_gemm_task(handle, scratch, a, b, c, omp_get_thread_num(), nthreads);
      }
# if defined(LIBXSMM_EXT_TASKS)
      else { /* tasks requested */
        const int ntasks = nthreads; /* TODO: apply grain-size */
# pragma omp parallel num_threads(nthreads)
        { /* first thread discovering work will launch all tasks */
# pragma omp single nowait /* anyone is good */
          { int tid; for (tid = 0; tid < ntasks; ++tid) {
# pragma omp task untied
              libxsmm_gemm_task(handle, scratch, a, b, c, tid, ntasks);
            }
          }
        } /* implicit synchronization (barrier) */
      }
# endif
    }
    else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
      const int ntasks = nthreads; /* TODO: apply grain-size */
      int tid; for (tid = 0; tid < ntasks; ++tid) {
# pragma omp task untied
        libxsmm_gemm_task(handle, scratch, a, b, c, tid, ntasks);
      }
      if (0 == libxsmm_nosync) { /* allow to omit synchronization */
# pragma omp taskwait
      }
# else
      libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*nthreads*/);
# endif
    }
    /* diagnostics: warn (once per worsening) when the tile decomposition
     * leaves a growing share of the available workers unused */
    if (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) { /* library code is expected to be mute */
      const unsigned int ntasks = handle->mt * handle->nt * handle->kt;
      const double imbalance = 100.0 * LIBXSMM_DELTA((unsigned int)nthreads, ntasks) / nthreads;
      static double max_imbalance = 50.0;
      if (max_imbalance < imbalance) {
        fprintf(stderr, "LIBXSMM WARNING: XGEMM %.0f%% imbalance (%u of %i workers utilized)!\n",
          imbalance, ntasks, nthreads);
        max_imbalance = imbalance;
      }
    }
#else
    libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*nthreads*/);
#endif /*defined(_OPENMP)*/
    libxsmm_free(scratch);
  }
  else { /* fallback or error */
    static int error_once = 0;
    if (NULL == handle) { /* fallback */
      if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */
        && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
      {
        fprintf(stderr, "LIBXSMM WARNING: XGEMM fallback code path triggered!\n");
      }
    }
    else if (0 != libxsmm_verbosity && /* library code is expected to be mute */
      1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: failed to allocate GEMM-scratch memory!\n");
    }
    libxsmm_blas_xgemm(iprec, oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
/* Core of the OpenMP-parallelized batched GEMM (grouped and non-grouped
 * interfaces). Processes up to LIBXSMM_GEMM_NPARGROUPS groups per round:
 * dispatches one SMM kernel per group, then runs the per-group batches in
 * parallel (threads or tasks). Any round whose groups are not suitable for
 * SMM kernels falls back to the BLAS-based path. A negative group_count
 * denotes the single-call (non-grouped) interface with a signed batchsize. */
LIBXSMM_API_INLINE void internal_gemm_batch_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char transa[], const char transb[], const libxsmm_blasint m[], const libxsmm_blasint n[], const libxsmm_blasint k[],
  const void* alpha, const void* a[], const libxsmm_blasint lda[], const void* b[], const libxsmm_blasint ldb[],
  const void* beta, void* c[], const libxsmm_blasint ldc[], libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint batchsize[], libxsmm_blasint group_count)
{
  static int error_once = 0;
  LIBXSMM_INIT
  if ( /* check for sensible arguments */
#if defined(LIBXSMM_BATCH_CHECK)
    NULL != a && NULL != b && NULL != c && (1 == group_count || -1 == group_count ||
    (0 == index_stride && (NULL == stride_a || 0 != *stride_a) && (NULL == stride_b || 0 != *stride_b) && (NULL == stride_c || 0 != *stride_c))) &&
#endif
    0 != group_count)
  {
    int result = EXIT_SUCCESS;
    const int max_npargroups = (int)(0 < libxsmm_gemm_npargroups
      ? LIBXSMM_MIN(libxsmm_gemm_npargroups, LIBXSMM_GEMM_NPARGROUPS) : LIBXSMM_GEMM_NPARGROUPS);
    const libxsmm_gemm_prefetch_type prefetch = libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO);
    /* effective strides; NULL selects pointer-array semantics (sizeof(void*)) */
    const size_t sa = (NULL != stride_a ? (size_t)(*stride_a) : sizeof(void*));
    const size_t sb = (NULL != stride_b ? (size_t)(*stride_b) : sizeof(void*));
    const size_t sc = (NULL != stride_c ? (size_t)(*stride_c) : sizeof(void*));
    const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)oprec);
    const int ngroups = (int)LIBXSMM_ABS(group_count);
    int group = 0, group_next = LIBXSMM_GEMM_NPARGROUPS;
    libxsmm_code_pointer kernel[LIBXSMM_GEMM_NPARGROUPS];
    libxsmm_blasint base[LIBXSMM_GEMM_NPARGROUPS], i;
#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
    int kflags[LIBXSMM_GEMM_NPARGROUPS];
#endif
    int max_nthreads = 1;
#if defined(_OPENMP)
# if defined(LIBXSMM_EXT_TASKS)
    const int outerpar = omp_get_active_level();
# else
    const int outerpar = omp_in_parallel();
# endif
    if (0 == outerpar) max_nthreads = omp_get_max_threads();
#endif
    for (i = 0; i < max_npargroups; ++i) {
#if !defined(NDEBUG)
      kernel[i].ptr = NULL;
# if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
      kflags[i] = 0;
# endif
#endif
      base[i] = 0;
    }
    for (group = 0; group < ngroups; group = group_next, group_next += max_npargroups) {
      /* NOTE(review): npargroups is min(group_next, ngroups), i.e. an absolute
       * end-index rather than a per-round count; for ngroups larger than
       * max_npargroups the later rounds appear to over-iterate — verify */
      const int npargroups = LIBXSMM_MIN(group_next, ngroups);
      libxsmm_blasint size = 0;
      int suitable = 0;
      if (0 < group) { /* base is maintained even if par-group is not suitable */
        for (i = 0; i < npargroups; ++i) {
          const libxsmm_blasint isize = batchsize[group+i-1], asize = LIBXSMM_ABS(isize);
          base[i] += asize;
        }
      }
      /* per group of this round: check SMM suitability and dispatch a kernel */
      for (i = 0; i < npargroups; ++i) {
        const libxsmm_blasint g = group + i, im = m[g], in = n[g], ik = k[g];
        suitable = LIBXSMM_SMM_AI(im, in, ik, 2/*RFO*/, otypesize);
        if (0 != suitable) {
          const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
          const char *const ta = (NULL != transa ? (transa + g) : NULL);
          const char *const tb = (NULL != transb ? (transb + g) : NULL);
          const int flags = LIBXSMM_GEMM_PFLAGS(ta, tb, LIBXSMM_FLAGS);
          /* NOTE(review): galpha[g]/gbeta[g] index beyond &alpha/&beta for
           * g > 0 (coverity suppressions below acknowledge this); confirm the
           * intended grouped alpha/beta layout */
          const void **const galpha = &alpha, **const gbeta = &beta;
          libxsmm_descriptor_blob blob;
          /* coverity[ptr_arith] */
          libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, im, in, ik,
            NULL != lda ? lda[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & flags) ? im : ik),
            NULL != ldb ? ldb[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & flags) ? ik : in),
            NULL != ldc ? ldc[g] : im, NULL != alpha ? galpha[g] : NULL, NULL != beta ? gbeta[g] : NULL,
            flags, prefetch);
          if (NULL != desc) {
            libxsmm_gemm_internal_set_batchflag(desc, c, index_stride, 0 < group_count ? isize : -asize, 1 != max_nthreads);
            kernel[i].xgemm = libxsmm_xmmdispatch(desc);
          }
          else kernel[i].ptr = NULL;
          if (NULL != kernel[i].ptr_const) {
            if (size < asize) size = asize; /* track the largest batch of the round */
#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
            LIBXSMM_ASSERT(NULL != desc); /* coverity[var_deref_op] */
            kflags[i] = desc->flags;
#endif
          }
          else {
            suitable = 0;
            break;
          }
        }
        else break;
      }
      if (0 != suitable) { /* check if an SMM is suitable */
        const unsigned char itypesize = libxsmm_typesize((libxsmm_datatype)iprec);
#if defined(_OPENMP)
        const int nchunks = (int)LIBXSMM_UPDIV(size, libxsmm_gemm_taskgrain);
        const int ntasks = nchunks * npargroups, nthreads = LIBXSMM_MIN(max_nthreads, ntasks);
        if (1 < nthreads) {
          if (0 == outerpar) { /* enable internal parallelization */
# if defined(LIBXSMM_EXT_TASKS)
            if (0 == libxsmm_gemm_tasks)
# endif
            {
# pragma omp parallel for num_threads(nthreads) private(i)
              for (i = 0; i < ntasks; ++i) {
                /* u: group within the round, v: chunk offset within the group */
                const libxsmm_blasint j = i * libxsmm_gemm_taskgrain, u = j / size, v = j - u * size, g = group + u;
                const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
                if (v < asize) {
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                  libxsmm_mmkernel_info kernel_info;
#endif
                  /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride, stride_a, stride_b, stride_c,
                    (const char*)a + sa * base[u], (const char*)b + sb * base[u], (char*)c + sc * base[u],
                    0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                    EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info) ? kernel_info.flags : 0);
#else
                    kflags[g]);
#endif
                }
              }
            }
# if defined(LIBXSMM_EXT_TASKS)
            else { /* tasks requested */
# pragma omp parallel num_threads(nthreads) private(i)
              { /* first thread discovering work will launch all tasks */
# pragma omp single nowait /* anyone is good */
                for (i = 0; i < ntasks; ++i) {
                  const libxsmm_blasint j = i * libxsmm_gemm_taskgrain, u = j / size, v = j - u * size, g = group + u;
                  const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
                  if (v < asize) {
# pragma omp task
                    {
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                      libxsmm_mmkernel_info kernel_info;
#endif
                      /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride, stride_a, stride_b, stride_c,
                        (const char*)a + sa * base[u], (const char*)b + sb * base[u], (char*)c + sc * base[u],
                        0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                        EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info) ? kernel_info.flags : 0);
#else
                        kflags[g]);
#endif
                    }
                  }
                }
              } /* implicit synchronization (barrier) */
            }
# endif
          }
          else { /* assume external parallelization */
            for (i = 0; i < (libxsmm_blasint)ntasks; ++i) {
              const libxsmm_blasint j = i * libxsmm_gemm_taskgrain, u = j / size, v = j - u * size, g = group + u;
              const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
              if (v < asize) {
# if defined(LIBXSMM_EXT_TASKS) /* OpenMP-tasks */
# pragma omp task
#endif
                {
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                  libxsmm_mmkernel_info kernel_info;
#endif
                  /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride, stride_a, stride_b, stride_c,
                    (const char*)a + sa * base[u], (const char*)b + sb * base[u], (char*)c + sc * base[u],
                    0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                    EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info) ? kernel_info.flags : 0);
#else
                    kflags[g]);
#endif
                }
              }
            }
# if defined(LIBXSMM_EXT_TASKS) /* OpenMP-tasks */
            if (0 == libxsmm_nosync) { /* allow to omit synchronization */
# pragma omp taskwait
            }
# endif
          }
        }
        else
#endif /*defined(_OPENMP)*/
        { /* sequential */
          for (i = 0; i < npargroups; ++i) {
            const libxsmm_blasint g = group + i;
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
            libxsmm_mmkernel_info kernel_info;
#endif
            libxsmm_mmbatch_kernel(kernel[i].xgemm, index_base, index_stride, stride_a, stride_b, stride_c,
              (const char*)a + sa * base[i], (const char*)b + sb * base[i], (char*)c + sc * base[i], batchsize[g],
              0/*tid*/, 1/*nthreads*/, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
              EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[i].xgemm, &kernel_info) ? kernel_info.flags : 0);
#else
              kflags[i]);
#endif
          }
        }
      }
      else { /* trigger fallback */
        result = EXIT_FAILURE;
      }
      if (EXIT_SUCCESS != result) {
        /* BLAS-based fallback for every group of this round */
        for (i = 0; i < npargroups; ++i) {
          const libxsmm_blasint g = group + i;
          const char *const ta = (NULL != transa ? (transa + g) : NULL);
          const char *const tb = (NULL != transb ? (transb + g) : NULL);
          const int flags = LIBXSMM_GEMM_PFLAGS(ta, tb, LIBXSMM_FLAGS);
          const libxsmm_blasint im = m[g], in = n[g], ik = k[g];
          const libxsmm_blasint ilda = (NULL != lda ? lda[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & flags) ? im : ik));
          const libxsmm_blasint ildb = (NULL != ldb ? ldb[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & flags) ? ik : in));
          const libxsmm_blasint ildc = (NULL != ldc ? ldc[g] : im);
          const void **const galpha = &alpha, **const gbeta = &beta;
          /* coverity[overrun-local] */
          const void *const ialpha = (NULL != alpha ? galpha[g] : NULL);
          /* coverity[overrun-local] */
          const void *const ibeta = (NULL != beta ? gbeta[g] : NULL);
          if (EXIT_SUCCESS == libxsmm_mmbatch_blas(iprec, oprec, ta, tb, im, in, ik, ialpha,
            (const char*)a + sa * base[i], &ilda, (const char*)b + sb * base[i], &ildb, ibeta, (char*)c + sc * base[i], &ildc,
            index_base, index_stride, stride_a, stride_b, stride_c, batchsize[g]))
          {
            if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) {
              /* NOTE(review): third argument is 'im' — 'ik' looks intended
               * for the MNK-based warning threshold; verify */
              const size_t threshold = LIBXSMM_MNK_SIZE(im, in, im);
              static size_t threshold_max = 0;
              if (threshold_max < threshold) {
                LIBXSMM_STDIO_ACQUIRE();
                fprintf(stderr, "LIBXSMM WARNING: ");
                libxsmm_gemm_print2(stderr, iprec, oprec, ta, tb, &im, &in, &ik,
                  ialpha, NULL/*a*/, &ilda, NULL/*b*/, &ildb, ibeta, NULL/*c*/, &ildc);
                fprintf(stderr, " => batched GEMM/omp was falling back to BLAS!\n");
                LIBXSMM_STDIO_RELEASE();
                threshold_max = threshold;
              }
            }
          }
          else {
            if (0 != libxsmm_verbosity /* library code is expected to be mute */
              && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
            {
              fprintf(stderr, "LIBXSMM ERROR: libxsmm_gemm_batch_omp failed!\n");
            }
            return; /* exit routine */
          }
        }
      }
    }
  }
#if defined(LIBXSMM_BATCH_CHECK)
  else if (0 != group_count && 0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM ERROR: incorrect arguments (libxsmm_gemm_batch_omp)!\n");
  }
#endif
}
/* Public entry point of the OpenMP-parallelized (non-grouped) batched GEMM:
 * forwards the scalar shape arguments by address to the shared implementation
 * as a single group (group_count == 1). */
LIBXSMM_APIEXT void libxsmm_gemm_batch_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  libxsmm_blasint batchsize)
{
  internal_gemm_batch_omp(iprec, oprec, transa, transb, &m, &n, &k,
    alpha, (const void**)a, lda, (const void**)b, ldb, beta, (void**)c, ldc, index_base, index_stride,
    stride_a, stride_b, stride_c, &batchsize, 1);
}
/* OpenMP-parallelized grouped DGEMM_BATCH: forwards the pointer-array form
 * (stride == sizeof(void*), zero index base/stride) to the shared
 * implementation; a NULL group_count is a no-op. */
LIBXSMM_APIEXT void libxsmm_dgemm_batch_omp(
  const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  if (NULL == group_count) return; /* nothing to do */
  {
    const libxsmm_blasint ptrsize = sizeof(void*);
    internal_gemm_batch_omp(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, (const void**)a_array, lda_array, (const void**)b_array, ldb_array, beta_array, (void**)c_array, ldc_array,
      0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, group_size, *group_count);
  }
}
/* OpenMP-parallelized grouped batch interface for single precision
 * (MKL/CBLAS-style gemm_batch signature with arrays of pointers).
 * A NULL group_count is treated as an empty request. */
LIBXSMM_APIEXT void libxsmm_sgemm_batch_omp(
  const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  /* pointer-array layout: stride between consecutive a/b/c entries is sizeof(void*) */
  const libxsmm_blasint ptrsize = sizeof(void*);
  if (NULL == group_count) return; /* nothing to do */
  internal_gemm_batch_omp(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32,
    transa_array, transb_array, m_array, n_array, k_array, alpha_array,
    (const void**)a_array, lda_array, (const void**)b_array, ldb_array,
    beta_array, (void**)c_array, ldc_array, 0/*index_base*/, 0/*index_stride*/,
    &ptrsize, &ptrsize, &ptrsize, group_size, *group_count);
}
/* Begin batch-recording of intercepted GEMM calls (only effective when the
 * GEMM wrapper is built in, i.e. LIBXSMM_WRAP && LIBXSMM_BUILD_EXT; a no-op
 * otherwise). Calls matching the descriptor constructed from the given
 * arguments are recorded into libxsmm_mmbatch_array and later processed by
 * libxsmm_mmbatch_end. All pointer arguments must be non-NULL for recording
 * to be enabled (no filtering with partially specified criteria). The
 * previously active descriptor is pushed onto a small ring buffer so
 * begin/end pairs can be nested (LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH levels). */
LIBXSMM_APIEXT void libxsmm_mmbatch_begin(libxsmm_gemm_precision precision,
  const int* flags, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc,
  const void* alpha, const void* beta)
{
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 26115) /* try-lock is treated incorrectly by static analysis */
# endif
  LIBXSMM_INIT
  if (NULL != libxsmm_mmbatch_array /* batch-recording available, but not yet running */
    /* currently, batch recording is only enabled if all values are present (no complex filtering) */
    && NULL != flags && NULL != alpha && NULL != beta
    && NULL != lda && NULL != ldb && NULL != ldc
    && NULL != m && NULL != n && NULL != k
    /* try-lock: if another thread holds the batch lock, silently do nothing */
    && LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_DEFAULT) == LIBXSMM_LOCK_TRYLOCK(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock))
  {
    libxsmm_descriptor_blob blob;
    const libxsmm_gemm_descriptor *const descriptor = libxsmm_gemm_descriptor_init(&blob, precision,
      *m, *n, *k, *lda, *ldb, *ldc, alpha, beta, *flags, libxsmm_get_gemm_prefetch(LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH));
    static int error_once = 0;
    int result = EXIT_SUCCESS;
    if (NULL != descriptor) {
      const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
      unsigned int i;
#if !defined(NDEBUG)
      const unsigned int mmbatch_maxdepth = LIBXSMM_UP2POT(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH);
      LIBXSMM_ASSERT((LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) == mmbatch_maxdepth/*is pot*/);
#endif
      /* eventually overwrite the oldest entry */
      i = LIBXSMM_MOD2(internal_ext_gemm_batchdepth, LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH);
      internal_ext_gemm_batchdesc[i] = libxsmm_mmbatch_desc; /* backup */
      ++internal_ext_gemm_batchdepth;
      /* ensure descriptor does not match any GEMM such that... */
      LIBXSMM_MEMZERO127(&libxsmm_mmbatch_desc);
      /* ...the batch stops and completely flushes */
      if (0 != internal_ext_gemm_batchsize) { /* flush calls recorded under the previous descriptor */
        result = internal_mmbatch_flush(internal_ext_gemm_batchdesc + i,
          (((libxsmm_blasint)internal_ext_gemm_batchsize - 1) % max_batchsize) + 1,
          (libxsmm_mmbatch_item*)libxsmm_mmbatch_array);
      }
      if (EXIT_SUCCESS == result) { /* enable descriptor */
        internal_ext_gemm_batchsize = 0; /* reset */
        if (0 == (LIBXSMM_MMBATCH_FLAG_STATISTIC & *flags)) {
          libxsmm_mmbatch_desc = *descriptor;
        }
        else { /* statistic mode: do not match a specific GEMM shape */
          libxsmm_mmbatch_desc.flags = LIBXSMM_MMBATCH_FLAG_STATISTIC;
        }
      }
    }
    else { /* descriptor could not be constructed from the given arguments */
      result = EXIT_FAILURE;
    }
    if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: GEMM batch enabling failed!\n");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock);
  }
# if defined(_MSC_VER)
# pragma warning(pop)
# endif
#else
  LIBXSMM_UNUSED(precision); LIBXSMM_UNUSED(flags);
  LIBXSMM_UNUSED(m); LIBXSMM_UNUSED(n); LIBXSMM_UNUSED(k);
  LIBXSMM_UNUSED(lda); LIBXSMM_UNUSED(ldb); LIBXSMM_UNUSED(ldc);
  LIBXSMM_UNUSED(alpha); LIBXSMM_UNUSED(beta);
#endif
}
/* End batch-recording of GEMM calls: flushes/processes everything recorded
 * since the matching libxsmm_mmbatch_begin and restores the previously
 * active batch descriptor from the nesting ring buffer. No-op unless the
 * GEMM wrapper is built in (LIBXSMM_WRAP && LIBXSMM_BUILD_EXT). */
LIBXSMM_APIEXT void libxsmm_mmbatch_end(void)
{
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 26115) /* try-lock is treated incorrectly by static analysis */
# endif
  /*const*/ int trystate = LIBXSMM_LOCK_TRYLOCK(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock);
  if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_DEFAULT) == trystate) { /* only act if the lock was acquired */
    const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
    const libxsmm_gemm_descriptor flushdesc = libxsmm_mmbatch_desc; /* copy before it is zeroed below */
    static int error_once = 0;
#if !defined(NDEBUG)
    const unsigned int mmbatch_maxdepth = LIBXSMM_UP2POT(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH);
#endif
    /* ensure descriptor does not match any GEMM such that... */
    LIBXSMM_MEMZERO127(&libxsmm_mmbatch_desc);
    /* ...the batch stops and completely flushes */
    if (EXIT_SUCCESS == internal_mmbatch_flush(&flushdesc,
      0 != internal_ext_gemm_batchsize ? (((internal_ext_gemm_batchsize - 1) % max_batchsize) + 1) : 0,
      (libxsmm_mmbatch_item*)libxsmm_mmbatch_array))
    {
      internal_ext_gemm_batchsize = 0; /* reset */
      --internal_ext_gemm_batchdepth; /* restore the previous descriptor */
      assert((LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) == mmbatch_maxdepth/*is pot*/); /* no LIBXSMM_ASSERT! */
      libxsmm_mmbatch_desc = internal_ext_gemm_batchdesc[LIBXSMM_MOD2(internal_ext_gemm_batchdepth, LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH)];
    }
    else if (0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: GEMM batch processing failed!\n");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock);
  }
# if defined(_MSC_VER)
# pragma warning(pop)
# endif
#endif
}
#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_xgemm_omp)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
const double*, const double*, const libxsmm_blasint*, const double*, const libxsmm_blasint*,
const double*, double*, const libxsmm_blasint*);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_xgemm_omp)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
const double* alpha, const double* a, const libxsmm_blasint* lda, const double* b, const libxsmm_blasint* ldb,
const double* beta, double* c, const libxsmm_blasint* ldc)
{
LIBXSMM_ASSERT(NULL != iprec && NULL != oprec);
libxsmm_xgemm_omp(*iprec, *oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* implementation provided for Fortran 77 compatibility:
 * plain pass-through to the C entry point (all arguments already by reference) */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_dgemm_omp)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*,
  const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_dgemm_omp)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  libxsmm_dgemm_omp(transa, transb, m, n, k,
    alpha, a, lda, b, ldb, beta, c, ldc);
}
/* implementation provided for Fortran 77 compatibility:
 * plain pass-through to the C entry point (all arguments already by reference) */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_sgemm_omp)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_sgemm_omp)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  libxsmm_sgemm_omp(transa, transb, m, n, k,
    alpha, a, lda, b, ldb, beta, c, ldc);
}
/* implementation provided for Fortran 77 compatibility:
 * dereferences the by-reference scalars and forwards to the C entry point */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_gemm_batch_omp)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*,
  const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[],
  const libxsmm_blasint*);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_gemm_batch_omp)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint* batchsize)
{
  /* every scalar required by the C interface must be present */
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize);
  libxsmm_gemm_batch_omp(*iprec, *oprec, transa, transb, *m, *n, *k,
    alpha, a, lda, b, ldb, beta, c, ldc, *index_base, *index_stride,
    stride_a, stride_b, stride_c, *batchsize);
}
/* implementation provided for Fortran 77 compatibility:
 * only the precision needs dereferencing; the rest is by reference anyway */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_begin)(const libxsmm_gemm_precision*,
  const int*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_begin)(const libxsmm_gemm_precision* precision,
  const int* flags, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc,
  const void* alpha, const void* beta)
{
  LIBXSMM_ASSERT(NULL != precision);
  libxsmm_mmbatch_begin(*precision, flags, m, n, k,
    lda, ldb, ldc, alpha, beta);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_end)(void);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_end)(void)
{
  /* argument-free pass-through to the C entry point */
  libxsmm_mmbatch_end();
}
#endif /*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_xcopy.h"
#include "libxsmm_ext.h"
/* Heuristic deciding whether a matrix-copy is worth multi-threading: the tile
 * (MT x NT) must fit into the problem (M x N), and the problem must comprise
 * at least 64*64 elements (product evaluated in unsigned int arithmetic). */
#define LIBXSMM_MCOPY_MT(MT, NT, M, N) ((MT) <= (M) && (NT) <= (N) && (64U * 64U) <= (((unsigned int)(M)) * (N)))
/* OpenMP-parallelized matrix-copy (in != NULL) or matrix-zeroing (in == NULL)
 * of an m x n matrix with typesize-byte elements and leading dimensions
 * ldi (input) / ldo (output). Small problems, or builds without OpenMP, fall
 * back to the sequential non-JIT kernels. An empty (0 x 0) request is valid
 * and does nothing. */
LIBXSMM_APIEXT void libxsmm_matcopy_omp(void* out, const void* in, unsigned int typesize,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo)
{
  LIBXSMM_INIT
  /* validate: element size in (0,256), both leading dimensions at least m,
   * distinct buffers, and either a non-empty matrix with non-NULL output
   * or an empty request */
  if (0 < typesize && 256 > typesize && m <= ldi && m <= ldo && out != in &&
    ((NULL != out && 0 < m && 0 < n) || (0 == m && 0 == n)))
  {
    if (0 < m && 0 < n) {
#if defined(_OPENMP)
      unsigned int tm, tn, ts; /* tile extents (elements) and tile budget (bytes) */
      if (NULL != in) { /* mcopy */
        tm = LIBXSMM_UPDIV(libxsmm_mcopy_mbytes, typesize);
        tn = (unsigned int)(libxsmm_mcopy_nscale * tm);
        ts = libxsmm_mcopy_mbytes;
      }
      else { /* mzero */
        tm = LIBXSMM_UPDIV(libxsmm_mzero_mbytes, typesize);
        tn = (unsigned int)(libxsmm_mzero_nscale * tm);
        ts = libxsmm_mzero_mbytes;
      }
      /* if the tunables are unset, fall back to the problem extent / minimum tile */
      if (0 == tm) tm = m;
      if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n);
      if (0 != ts && ts < (tm * tn * typesize)) { /* cap the tile footprint at ts bytes */
        tm = LIBXSMM_MAX(ts / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN);
      }
      if (LIBXSMM_MCOPY_MT(tm, tn, (unsigned int)m, (unsigned int)n)) { /* consider problem-size */
        libxsmm_xcopykernel kernel;
        kernel.ptr = NULL; /* no kernel unless JIT dispatch below succeeds */
# if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 2))
        if (0 != (2 & libxsmm_xcopy_jit)) { /* JIT'ted matrix-copy permitted? */
          /* dispatch a JIT'ted unary kernel: IDENTITY copies, XOR zeroes */
          switch (typesize) {
            case 8: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
                LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_MELTW_FLAG_UNARY_NONE,
                NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
            case 4: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
                LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE,
                NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
            case 2: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
                LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_MELTW_FLAG_UNARY_NONE,
                NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
            case 1: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
                LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_MELTW_FLAG_UNARY_NONE,
                NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
          }
        }
# endif
# if defined(LIBXSMM_EXT_TASKS) && 0/* implies _OPENMP */
        if (0 == omp_get_active_level())
# else
        if (0 == omp_in_parallel())
# endif
        { /* enable internal parallelization */
          const int nthreads = omp_get_max_threads();
# if defined(LIBXSMM_EXT_TASKS)
          if (0 >= libxsmm_xcopy_taskscale)
# endif
          {
# pragma omp parallel num_threads(nthreads)
            libxsmm_matcopy_task_internal(out, in, typesize,
              (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
              tm, tn, kernel, omp_get_thread_num(), nthreads);
          }
# if defined(LIBXSMM_EXT_TASKS)
          else { /* tasks requested */
            const int ntasks = nthreads * libxsmm_xcopy_taskscale;
# pragma omp parallel num_threads(nthreads)
            { /* first thread discovering work will launch all tasks */
# pragma omp single nowait /* anyone is good */
              { int tid;
                for (tid = 0; tid < ntasks; ++tid) {
# pragma omp task untied
                  libxsmm_matcopy_task_internal(out, in, typesize,
                    (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                    tm, tn, kernel, tid, ntasks);
                }
              }
            }
          }
# endif
        }
        else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
          const int nthreads = omp_get_num_threads();
          const int ntasks = (0 == libxsmm_xcopy_taskscale
            ? (LIBXSMM_XCOPY_TASKSCALE)
            : libxsmm_xcopy_taskscale) * nthreads;
          int tid;
          for (tid = 0; tid < ntasks; ++tid) {
# pragma omp task untied
            libxsmm_matcopy_task_internal(out, in, typesize,
              (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
              tm, tn, kernel, tid, ntasks);
          }
          if (0 == libxsmm_nosync) { /* allow to omit synchronization */
# pragma omp taskwait
          }
# else
          libxsmm_matcopy_task_internal(out, in, typesize,
            (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
            tm, tn, kernel, 0/*tid*/, 1/*nthreads*/);
# endif
        }
      }
      else
#endif /*defined(_OPENMP)*/
      if (NULL != in) { /* no MT, or small problem-size */
        LIBXSMM_XCOPY_NONJIT(LIBXSMM_MCOPY_KERNEL,
          typesize, out, in, ldi, ldo, 0, m, 0, n);
      }
      else { /* no MT, or small problem-size */
        /* coverity[ptr_arith] */
        LIBXSMM_XCOPY_NONJIT(LIBXSMM_MZERO_KERNEL,
          typesize, out, in, ldi, ldo, 0, m, 0, n);
      }
    }
  }
  else { /* argument error: report the first applicable cause (once) */
    static int error_once = 0;
    if ( 0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      if (NULL == out) {
        fprintf(stderr, "LIBXSMM ERROR: the matrix-copy input and/or output is NULL!\n");
      }
      else if (out == in) {
        fprintf(stderr, "LIBXSMM ERROR: output and input of the matrix-copy must be different!\n");
      }
      else if (0 == typesize || 256 <= typesize) {
        fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-copy specified!\n");
      }
      else if (ldi < m || ldo < m) {
        fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the matrix-copy is/are too small!\n");
      }
      else if (0 > m || 0 > n) {
        fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the matrix-copy is/are negative!\n");
      }
    }
  }
}
/* OpenMP-parallelized out-of-place transpose of an m x n matrix with
 * typesize-byte elements; ldi is the input leading dimension, ldo the output
 * (transposed) leading dimension, hence the n <= ldo requirement. When
 * out == in and ldi == ldo, an in-place transpose is performed instead.
 * Small problems use sequential (optionally JIT'ted) kernels. An empty
 * (0 x 0) request is valid and does nothing. */
LIBXSMM_APIEXT void libxsmm_otrans_omp(void* out, const void* in, unsigned int typesize,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo)
{
  static int error_once = 0;
  LIBXSMM_INIT
  /* validate: element size in (0,256), ldi >= m, ldo >= n, and either a
   * non-empty matrix with non-NULL buffers or an empty request */
  if (0 < typesize && 256 > typesize && m <= ldi && n <= ldo &&
    ((NULL != out && NULL != in && 0 < m && 0 < n) || (0 == m && 0 == n)))
  {
    if (0 < m && 0 < n) {
      if (out != in) { /* out-of-place transpose */
#if defined(_OPENMP)
        /* tile extents derived from the tunable tile budget (bytes) and scale */
        unsigned int tm = LIBXSMM_UPDIV(libxsmm_tcopy_mbytes, typesize);
        unsigned int tn = (unsigned int)(libxsmm_tcopy_nscale * tm);
        if (0 == tm) tm = m;
        if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n);
        if (0 != libxsmm_tcopy_mbytes && libxsmm_tcopy_mbytes < (tm * tn * typesize)) {
          /* cap the tile footprint at the configured byte budget */
          tm = LIBXSMM_MAX(libxsmm_tcopy_mbytes / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN);
        }
        if (tm <= (unsigned int)m && tn <= (unsigned int)n) { /* consider problem-size */
          libxsmm_xcopykernel kernel;
          kernel.ptr = NULL; /* multi-threaded path runs without a JIT kernel */
# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
          if (0 == omp_get_active_level())
# else
          if (0 == omp_in_parallel())
# endif
          { /* enable internal parallelization */
            const int nthreads = omp_get_max_threads();
# if defined(LIBXSMM_EXT_TASKS)
            if (0 >= libxsmm_xcopy_taskscale)
# endif
            {
# pragma omp parallel num_threads(nthreads)
              { /* coverity[divide_by_zero] */
                libxsmm_otrans_task_internal(out, in, typesize,
                  (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                  tm, tn, kernel, omp_get_thread_num(), nthreads);
              }
            }
# if defined(LIBXSMM_EXT_TASKS)
            else { /* tasks requested */
              const int ntasks = nthreads * libxsmm_xcopy_taskscale;
# pragma omp parallel num_threads(nthreads)
              { /* first thread discovering work will launch all tasks */
# pragma omp single nowait /* anyone is good */
                { int tid;
                  for (tid = 0; tid < ntasks; ++tid) {
# pragma omp task untied
                    libxsmm_otrans_task_internal(out, in, typesize,
                      (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                      tm, tn, kernel, tid, ntasks);
                  }
                }
              }
            }
# endif
          }
          else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
            const int nthreads = omp_get_num_threads();
            const int ntasks = (0 == libxsmm_xcopy_taskscale
              ? (LIBXSMM_XCOPY_TASKSCALE)
              : libxsmm_xcopy_taskscale) * nthreads;
            int tid;
            for (tid = 0; tid < ntasks; ++tid) {
# pragma omp task untied
              libxsmm_otrans_task_internal(out, in, typesize,
                (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                tm, tn, kernel, tid, ntasks);
            }
            if (0 == libxsmm_nosync) { /* allow to omit synchronization */
# pragma omp taskwait
            }
# else /* coverity[divide_by_zero] */
            libxsmm_otrans_task_internal(out, in, typesize,
              (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
              tm, tn, kernel, 0/*tid*/, 1/*nthreads*/);
# endif
          }
        }
        else
#endif /*defined(_OPENMP)*/
        { /* no MT, or small problem-size */
#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1))
          libxsmm_xcopykernel kernel;
          kernel.ptr = NULL;
          if (0 != (1 & libxsmm_xcopy_jit)) { /* JIT'ted transpose permitted? */
            /* dispatch a JIT'ted NORM->NORMT transform matching the element width */
            switch (typesize) {
              case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                  LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
                  LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
              case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                  LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
                  LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
              case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                  LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16,
                  LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
              case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                  LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8,
                  LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
            }
            if (NULL != kernel.ptr) { /* JIT-kernel available */
              LIBXSMM_TCOPY_CALL(kernel, typesize, in, ldi, out, ldo);
            }
          }
          else
#endif
          { /* generic element-wise fallback */
            LIBXSMM_XCOPY_NONJIT(LIBXSMM_TCOPY_KERNEL,
              typesize, out, in, ldi, ldo, 0, m, 0, n);
          }
        }
      }
      else if (ldi == ldo) { /* in-place transpose */
        libxsmm_itrans/*TODO: omp*/(out, typesize, m, n, ldi, ldo);
      }
      else if (0 != libxsmm_verbosity /* library code is expected to be mute */
        && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
      { /* out == in but ldi != ldo: unsupported */
        fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n");
      }
    }
  }
  else { /* argument error: report the first applicable cause (once) */
    if (0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      if (NULL == out || NULL == in) {
        fprintf(stderr, "LIBXSMM ERROR: the transpose input and/or output is NULL!\n");
      }
      else if (out == in) {
        fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n");
      }
      else if (0 == typesize || 256 <= typesize) {
        fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-transpose specified!\n");
      }
      else if (ldi < m || ldo < n) {
        fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the transpose is/are too small!\n");
      }
      else if (0 > m || 0 > n) {
        fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the transpose is/are negative!\n");
      }
    }
  }
}
/* OpenMP-parallelized batch of in-place transposes: each matrix in the batch
 * (addressed inside "inout" via index_base/index_stride/stride) is transposed
 * in place. Matrices that cannot be transposed directly (m != n, ldi != ldo,
 * or typesize > 127) are staged through a scratch buffer: a stack buffer if
 * one matrix fits in LIBXSMM_ITRANS_BUFFER_MAXSIZE bytes, scratch-allocated
 * memory otherwise. For batchsize <= 1 (or without OpenMP) the sequential
 * libxsmm_itrans_batch is called instead. */
LIBXSMM_APIEXT void libxsmm_itrans_batch_omp(void* inout, unsigned int typesize,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo,
  libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride[], libxsmm_blasint batchsize)
{
#if defined(_OPENMP)
  if (1 < batchsize) { /* consider problem-size */
    const libxsmm_blasint scratchsize = m * n * typesize; /* bytes per matrix */
    const libxsmm_blasint size = LIBXSMM_ABS(batchsize);
    char buffer[LIBXSMM_ITRANS_BUFFER_MAXSIZE]; /* stack-based staging area */
    char *const mat0 = (char*)inout;
    void* scratch = NULL;
    libxsmm_xcopykernel kernel = { NULL };
    if (m != n || ldi != ldo || 127 < typesize) { /* staging buffer required */
      if (scratchsize <= LIBXSMM_ITRANS_BUFFER_MAXSIZE) {
        scratch = buffer;
      }
      else { /* matrix too large for the stack buffer: allocate scratch memory */
        static int error_once = 0;
        LIBXSMM_INIT
        if (EXIT_SUCCESS != libxsmm_xmalloc(&scratch, scratchsize, 0/*auto-align*/,
          LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE,
          0/*extra*/, 0/*extra_size*/)
          && 0 != libxsmm_verbosity /* library code is expected to be mute */
          && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
        {
          fprintf(stderr, "LIBXSMM ERROR: failed to allocate buffer for in-place transpose!\n");
        }
      }
#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1))
      if (0 != (1 & libxsmm_xcopy_jit) /* JIT'ted transpose permitted? */
        /* avoid outgrown transpose kernel upfront */
        && (m <= LIBXSMM_CONFIG_MAX_DIM || n <= LIBXSMM_CONFIG_MAX_DIM))
      {
        /* dispatch a JIT'ted NORM->NORMT transform matching the element width */
        switch (typesize) {
          case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
              LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
              LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
          case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
              LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
              LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
          case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
              LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16,
              LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
          case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
              LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8,
              LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
        }
      }
#endif
    }
# if defined(LIBXSMM_EXT_TASKS) && 0/* implies _OPENMP */
    if (0 == omp_get_active_level())
# else
    if (0 == omp_in_parallel())
# endif
    { /* enable internal parallelization */
      const int nthreads = omp_get_max_threads();
# if defined(LIBXSMM_EXT_TASKS)
      if (0 >= libxsmm_xcopy_taskscale)
# endif
      {
        const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, nthreads);
# pragma omp parallel num_threads(nthreads)
        { /* each thread transposes a contiguous sub-range of the batch */
          const libxsmm_blasint begin = omp_get_thread_num() * tasksize;
          const libxsmm_blasint span = begin + tasksize;
          libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base,
            index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size));
        }
      }
# if defined(LIBXSMM_EXT_TASKS)
      else { /* tasks requested */
        const int ntasks = nthreads * libxsmm_xcopy_taskscale;
        const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks);
# pragma omp parallel num_threads(nthreads)
        { /* first thread discovering work will launch all tasks */
# pragma omp single nowait /* anyone is good */
          { int tid;
            for (tid = 0; tid < ntasks; ++tid) {
              const libxsmm_blasint begin = tid * tasksize;
              const libxsmm_blasint span = begin + tasksize;
# pragma omp task untied
              libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base,
                index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size));
            }
          }
        }
      }
# endif
    }
    else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
      const int nthreads = omp_get_num_threads();
      const int ntasks = (0 == libxsmm_xcopy_taskscale
        ? (LIBXSMM_XCOPY_TASKSCALE)
        : libxsmm_xcopy_taskscale) * nthreads;
      const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks);
      int tid;
      for (tid = 0; tid < ntasks; ++tid) {
        const libxsmm_blasint begin = tid * tasksize;
        const libxsmm_blasint span = begin + tasksize;
# pragma omp task untied
        libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base,
          index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size));
      }
      if (0 == libxsmm_nosync) { /* allow to omit synchronization */
# pragma omp taskwait
      }
# else
      libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo, index_base,
        index_stride, stride, kernel, 0, batchsize);
# endif
    }
    if (NULL != scratch && LIBXSMM_ITRANS_BUFFER_MAXSIZE < scratchsize) {
      /* release only heap scratch; the stack buffer needs no cleanup */
      libxsmm_xfree(scratch, 0/*no check*/);
    }
  }
  else
#endif /*defined(_OPENMP)*/
  /* sequential fallback (also covers batchsize <= 1) */
  libxsmm_itrans_batch(inout, typesize, m, n, ldi, ldo,
    index_base, index_stride, stride, batchsize,
    0/*tid*/, 1/*ntasks*/);
}
#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_matcopy_omp)(void* /*out*/, const void* /*in*/, const int* /*typesize*/,
const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_matcopy_omp)(void* out, const void* in, const int* typesize,
const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo)
{
libxsmm_blasint ldx;
LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m);
ldx = *(NULL != ldi ? ldi : m);
libxsmm_matcopy_omp(out, in, (unsigned int)*typesize, *m, *(NULL != n ? n : m), ldx, NULL != ldo ? *ldo : ldx);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_otrans_omp)(void* /*out*/, const void* /*in*/, const int* /*typesize*/,
const libxsmm_blasint* /*m*/, const libxsmm_blasint* /*n*/, const libxsmm_blasint* /*ldi*/, const libxsmm_blasint* /*ldo*/);
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_otrans_omp)(void* out, const void* in, const int* typesize,
const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* ldi, const libxsmm_blasint* ldo)
{
libxsmm_blasint ldx;
LIBXSMM_ASSERT(NULL != typesize && 0 < *typesize && NULL != m);
ldx = *(NULL != ldi ? ldi : m);
libxsmm_otrans_omp(out, in, (unsigned int)*typesize, *m, *(NULL != n ? n : m), ldx, NULL != ldo ? *ldo : ldx);
}
#endif /*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment