Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c454d419
Commit
c454d419
authored
May 12, 2023
by
lisj
Browse files
删除子模块的gitignore
parent
3359c1f1
Changes
264
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
8567 additions
and
0 deletions
+8567
-0
third_party/libxsmm/src/libxsmm_dnn_pooling.c
third_party/libxsmm/src/libxsmm_dnn_pooling.c
+451
-0
third_party/libxsmm/src/libxsmm_dnn_pooling_backward.c
third_party/libxsmm/src/libxsmm_dnn_pooling_backward.c
+301
-0
third_party/libxsmm/src/libxsmm_dnn_pooling_backward.h
third_party/libxsmm/src/libxsmm_dnn_pooling_backward.h
+20
-0
third_party/libxsmm/src/libxsmm_dnn_pooling_forward.c
third_party/libxsmm/src/libxsmm_dnn_pooling_forward.c
+301
-0
third_party/libxsmm/src/libxsmm_dnn_pooling_forward.h
third_party/libxsmm/src/libxsmm_dnn_pooling_forward.h
+20
-0
third_party/libxsmm/src/libxsmm_dnn_rnncell.c
third_party/libxsmm/src/libxsmm_dnn_rnncell.c
+2357
-0
third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.c
.../libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.c
+1016
-0
third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.h
.../libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.h
+21
-0
third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.c
third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.c
+740
-0
third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.h
third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.h
+21
-0
third_party/libxsmm/src/libxsmm_dnn_softmaxloss.c
third_party/libxsmm/src/libxsmm_dnn_softmaxloss.c
+382
-0
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.c
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.c
+103
-0
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.h
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.h
+18
-0
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.c
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.c
+103
-0
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.h
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.h
+18
-0
third_party/libxsmm/src/libxsmm_dnn_tensor.c
third_party/libxsmm/src/libxsmm_dnn_tensor.c
+642
-0
third_party/libxsmm/src/libxsmm_ext.c
third_party/libxsmm/src/libxsmm_ext.c
+267
-0
third_party/libxsmm/src/libxsmm_ext.h
third_party/libxsmm/src/libxsmm_ext.h
+46
-0
third_party/libxsmm/src/libxsmm_ext_gemm.c
third_party/libxsmm/src/libxsmm_ext_gemm.c
+1268
-0
third_party/libxsmm/src/libxsmm_ext_xcopy.c
third_party/libxsmm/src/libxsmm_ext_xcopy.c
+472
-0
No files found.
Too many changes to show.
To preserve performance only
264 of 264+
files are displayed.
Plain diff
Email patch
third_party/libxsmm/src/libxsmm_dnn_pooling.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_backward.h"
#include "libxsmm_dnn_pooling_forward.h"
#include "libxsmm_main.h"
/* Allocates and initializes a pooling handle from the given descriptor.
 * Supported datatype combinations are F32->F32 and BF16->BF16; anything else
 * yields LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE and a NULL handle.
 * On success *status is LIBXSMM_DNN_SUCCESS (or the result of the feature-map
 * blocking query) and the returned handle must be released with
 * libxsmm_dnn_destroy_pooling. */
LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(
  libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_pooling* handle = 0;
  int lpb; /* low-precision blocking factor returned by the query; unused here */
  /* init libxsmm */
  LIBXSMM_INIT
  if ( ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
       ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32)  && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    handle = (libxsmm_dnn_pooling*)calloc(1, sizeof(libxsmm_dnn_pooling));
    if (0 != handle) {
      *status = LIBXSMM_DNN_SUCCESS;
      /* let's make the description persistent */
      handle->desc = pooling_desc;
      /* we need to compute the memory layout given the */
      /* NOTE(review): status of the blocking query overwrites SUCCESS but is not
       * checked before the derived fields below are computed -- confirm intended */
      *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C,
        &(handle->ifmblock), &(handle->ofmblock), &lpb,
        handle->desc.datatype_in, handle->desc.datatype_out );
      /* compute the outer blocks (number of channel blocks per feature map) */
      handle->blocksifm = handle->desc.C / handle->ifmblock;
      handle->blocksofm = handle->desc.C / handle->ofmblock;
      /* setting ofh and ofw: standard output-extent formula for window RxS,
       * strides u/v and symmetric logical padding pad_h/pad_w */
      handle->ofh = (handle->desc.H + 2*handle->desc.pad_h - handle->desc.R)/handle->desc.u + 1;
      handle->ofw = (handle->desc.W + 2*handle->desc.pad_w - handle->desc.S)/handle->desc.v + 1;
      /* create barrier (one per handle, sized for the configured thread count) */
      handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
      /* calculate scratch size for local pooling copies of one feature map block per thread:
       * padded H x padded W x widest channel block x threads, in float elements */
      handle->scratch_size = ( sizeof(float)
        * ( (size_t)handle->desc.H + (size_t)LIBXSMM_MAX(handle->desc.pad_h_in, handle->desc.pad_h_out)*2 )
        * ( (size_t)handle->desc.W + (size_t)LIBXSMM_MAX(handle->desc.pad_w_in, handle->desc.pad_w_out)*2 )
        * LIBXSMM_MAX(handle->ofmblock, handle->ifmblock)
        * handle->desc.threads );
    }
    else {
      *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  return handle;
}
/* Releases a pooling handle created by libxsmm_dnn_create_pooling, including
 * its thread barrier. Passing NULL yields LIBXSMM_DNN_ERR_INVALID_HANDLE. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle)
{
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    /* release the thread barrier if one was created */
    if (0 != handle->barrier) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    /* deallocate the handle structure itself; the cast removes constness */
    free((libxsmm_dnn_pooling*)handle);
  }
  return result;
}
/* Builds the tensor datalayout the handle expects for a given tensor role.
 * Supported roles: regular/gradient input, regular/gradient output, and the
 * pooling mask. Two buffer formats are handled:
 *  - LIBXSMM (blocked): 5 dims ordered C(block), W, H, C(outer blocks), N;
 *  - NHWC: 4 dims ordered C, W, H, N (the mask role is not supported here and
 *    yields LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE).
 * On any failure the function returns NULL and sets *status accordingly.
 *
 * Fixes vs. previous revision:
 *  - the NHWC branch no longer overwrites layout->datatype with datatype_in
 *    after it was deliberately set to desc.datatype_mask for the mask role;
 *  - the NHWC branch now handles dim_type/dim_size allocation failure
 *    (previously a bare "TODO") instead of leaving a half-built layout;
 *  - allocation-failure paths free both dim arrays before freeing the layout
 *    (free(NULL) is a no-op), closing a leak when only one malloc succeeded. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(
  const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      layout->format = handle->desc.buffer_format;
      if ( (type == LIBXSMM_DNN_REGULAR_INPUT)  || (type == LIBXSMM_DNN_GRADIENT_INPUT)  || (type == LIBXSMM_DNN_INPUT)  ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ||
           (type == LIBXSMM_DNN_POOLING_MASK) ) {
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            /* the mask carries its own datatype; everything else is F32 here */
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(5 * sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*)malloc(5 * sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              /* blocked layout: innermost channel block, W, H, outer channel blocks, batch */
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                /* input tensors carry the physical input padding */
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                /* output tensors carry the physical output padding */
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) {
                /* the mask is unpadded output-sized */
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofw;
                layout->dim_size[2] = handle->ofh;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0;
                /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0;
              /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            /* identical structure to the F32 branch, with BF16 element type */
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(5 * sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*)malloc(5 * sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofw;
                layout->dim_size[2] = handle->ofh;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0;
                /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0;
              /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0;
            /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32)  && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
            /* BUGFIX: do not clobber the mask datatype with datatype_in afterwards */
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = handle->desc.datatype_in;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(4 * sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*)malloc(4 * sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->desc.N;
              } else {
                /* NOTE: LIBXSMM_DNN_POOLING_MASK is not supported in NHWC */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0;
                /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              /* BUGFIX: previously a TODO; handle allocation failure like the blocked branch */
              free(layout->dim_type);
              free(layout->dim_size);
              free(layout);
              layout = 0;
              /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0;
            /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0;
          /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else {
        free(layout);
        layout = 0;
        /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
/* Reports the scratch-buffer size (in bytes) the handle requires.
 * Returns 0 and sets LIBXSMM_DNN_ERR_INVALID_HANDLE when handle is NULL. */
LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(
  const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status)
{
  size_t result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    /* 64 byte extra in case the user code does not care about alignment */
    result = handle->scratch_size + 64;
  }
  return result;
}
/* Attaches user-provided scratch memory to the handle, rounding the pointer up
 * to the next 64-byte boundary if needed (the extra 64 bytes reported by
 * libxsmm_dnn_pooling_get_scratch_size cover this adjustment).
 * A NULL scratch pointer wins over a NULL handle in the error reporting. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(
  libxsmm_dnn_pooling* handle, const void* scratch)
{
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  const uintptr_t addr = (uintptr_t)scratch;
  if (0 == scratch) {
    result = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  }
  else if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    /* align the internal scratch buffer if needed */
    const uintptr_t misalign = addr % 64;
    handle->scratch = (0 == misalign ? (void*)addr : (void*)(addr + (64 - misalign)));
  }
  return result;
}
/* Detaches the scratch buffer from the handle. The memory itself is owned by
 * the caller and is not freed here. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle)
{
  libxsmm_dnn_err_t result;
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    handle->scratch = 0;
    result = LIBXSMM_DNN_SUCCESS;
  }
  return result;
}
/* Binds a tensor to one of the handle's five pooling roles after checking that
 * its layout matches the layout the handle would create for that role.
 * Errors: UNKNOWN_TENSOR_TYPE for any other role, INVALID_HANDLE_TENSOR when
 * handle or tensor is NULL, MISMATCH_TENSOR on a layout mismatch. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(
  libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
{
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  /* check for tensor type: only the pooling-related roles are accepted */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_POOLING_MASK:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle || 0 == tensor) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  else {
    /* derive the layout this handle expects for the requested role */
    libxsmm_dnn_tensor_datalayout* expected =
      libxsmm_dnn_pooling_create_tensor_datalayout(handle, type, &result);
    if (0 == libxsmm_dnn_compare_tensor_datalayout(expected, tensor->layout, &result)) {
      libxsmm_dnn_tensor* bound = (libxsmm_dnn_tensor*)tensor; /* drop constness for storage */
      switch (type) {
        case LIBXSMM_DNN_REGULAR_INPUT:   handle->reg_input   = bound; break;
        case LIBXSMM_DNN_GRADIENT_INPUT:  handle->grad_input  = bound; break;
        case LIBXSMM_DNN_REGULAR_OUTPUT:  handle->reg_output  = bound; break;
        case LIBXSMM_DNN_GRADIENT_OUTPUT: handle->grad_output = bound; break;
        default:                          handle->mask        = bound; break; /* LIBXSMM_DNN_POOLING_MASK */
      }
    }
    else {
      result = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout(expected);
  }
  return result;
}
/* Returns the tensor currently bound to the given pooling role, or NULL (with
 * *status set) for an unknown role or a NULL handle. An unbound role simply
 * returns the stored NULL pointer with SUCCESS. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(
  libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  /* check for tensor type */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_POOLING_MASK:
      break;
    default:
      *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      return result;
  }
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:   result = handle->reg_input;   break;
      case LIBXSMM_DNN_GRADIENT_INPUT:  result = handle->grad_input;  break;
      case LIBXSMM_DNN_REGULAR_OUTPUT:  result = handle->reg_output;  break;
      case LIBXSMM_DNN_GRADIENT_OUTPUT: result = handle->grad_output; break;
      default:                          result = handle->mask;        break; /* LIBXSMM_DNN_POOLING_MASK */
    }
  }
  return result;
}
/* Unbinds the tensor of the given pooling role by clearing the handle's
 * pointer; tensor memory ownership stays with the caller. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(
  libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type)
{
  libxsmm_dnn_err_t result = LIBXSMM_DNN_SUCCESS;
  /* check for tensor type */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_POOLING_MASK:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }
  if (0 == handle) {
    result = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  else {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:   handle->reg_input   = 0; break;
      case LIBXSMM_DNN_GRADIENT_INPUT:  handle->grad_input  = 0; break;
      case LIBXSMM_DNN_REGULAR_OUTPUT:  handle->reg_output  = 0; break;
      case LIBXSMM_DNN_GRADIENT_OUTPUT: handle->grad_output = 0; break;
      default:                          handle->mask        = 0; break; /* LIBXSMM_DNN_POOLING_MASK */
    }
  }
  return result;
}
/* Single-threaded-entry execution: dispatches a forward or backward pooling
 * pass for the thread identified by tid (start_thread is forwarded to the
 * kernel). Only the blocked LIBXSMM buffer format is executable here. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(
  libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
        switch (handle->desc.buffer_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
            status = libxsmm_dnn_pooling_st_fwd_custom(handle, start_thread, tid);
          } break;
          default: {
            /* NOTE(review): error code borrowed from the fused-BN module -- confirm intended */
            status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
          }
        }
      } break;
      case LIBXSMM_DNN_COMPUTE_KIND_BWD: {
        switch (handle->desc.buffer_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
            status = libxsmm_dnn_pooling_st_bwd_custom(handle, start_thread, tid);
          } break;
          default: {
            /* NOTE(review): error code borrowed from the fused-BN module -- confirm intended */
            status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_KIND;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
third_party/libxsmm/src/libxsmm_dnn_pooling_backward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_backward.h"
#include "libxsmm_main.h"
/* Forward declarations of the per-blocksize backward-pooling kernels defined
 * below; libxsmm_dnn_pooling_st_bwd_custom dispatches to them based on the
 * datatype pair and the channel-block width (ofmblock 16/32/64). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16(
  libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32(
  libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64(
  libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16(
  libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32(
  libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64(
  libxsmm_dnn_pooling* handle, int start_thread, int tid);
/* AVX512 backward pooling for F32 in/out with a 16-wide channel block.
 * The body is pulled in from a macro-parameterized template; the
 * LIBXSMM_DNN_POOLING_BWD_MAX/AVG defines select the pooling flavor and the
 * element_* typedefs select the element types inside the template. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
    /* max-pooling needs the int argmax mask recorded by the forward pass */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward pooling for F32 in/out with a 32-wide channel block; same
 * template-inclusion scheme as the c16 variant. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
    /* max-pooling needs the int argmax mask recorded by the forward pass */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward pooling for F32 in/out with a 64-wide channel block; same
 * template-inclusion scheme as the c16 variant. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
    /* max-pooling needs the int argmax mask recorded by the forward pass */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward pooling for BF16 in/out with a 16-wide channel block.
 * Shares the F32 template; LIBXSMM_DNN_POOLING_BWD_BF16 switches the template
 * to its bfloat16 code paths for the whole flavor selection below. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
    /* max-pooling needs the int argmax mask recorded by the forward pass */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_BWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward pooling for BF16 in/out with a 32-wide channel block; same
 * scheme as the bf16 c16 variant. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
    /* max-pooling needs the int argmax mask recorded by the forward pass */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_BWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 backward pooling for BF16 in/out with a 64-wide channel block; same
 * scheme as the bf16 c16 variant. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
    /* max-pooling needs the int argmax mask recorded by the forward pass */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_BWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Backward-pooling dispatcher for the blocked LIBXSMM format.
 * Validates that gradient input/output are bound (and the mask, for max
 * pooling), then selects a specialized AVX512 kernel when the target arch and
 * channel-block width (16/32/64) permit, falling back to the generic template
 * otherwise. Only F32->F32 and BF16->BF16 datatype pairs are supported. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom(
  libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and mask */
  if ( handle->grad_input == 0 || handle->grad_output == 0 ||
       ( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (handle->ofmblock == 16) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* NOTE(review): asserts a mask even for AVG pooling, where the bound-check
       * above does not require one -- confirm intended (debug builds only) */
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c16( handle, start_thread, tid );
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c16( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (handle->ofmblock == 32) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c32( handle, start_thread, tid );
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c32( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (handle->ofmblock == 64) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_bwd_custom_f32_f32_c64( handle, start_thread, tid );
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_bwd_custom_bf16_bf16_c64( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* generic fallback for non-AVX512 targets or other channel-block widths */
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
        typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
      } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
      } else {
        status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
      }
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_BWD_BF16
      if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_BWD_MAX
        typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_MAX
      } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_BWD_AVG
# include "template/libxsmm_dnn_pooling_st_bwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_BWD_AVG
      } else {
        status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
      }
# undef LIBXSMM_DNN_POOLING_BWD_BF16
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/** Multithreaded backward pooling for the NHWC buffer format.
 *  Not implemented: the parameters are only touched to silence
 *  unused-parameter warnings.
 *  @return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED always */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
third_party/libxsmm/src/libxsmm_dnn_pooling_backward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_POOLING_BACKWARD_H
#define LIBXSMM_DNN_POOLING_BACKWARD_H
#include <libxsmm_dnn_pooling.h>
/** Thread entry point for backward pooling on the custom (blocked) tensor
 *  format; start_thread/tid identify the calling thread within the team. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid);
/** Thread entry point for backward pooling on the NHWC format
 *  (returns LIBXSMM_DNN_ERR_NOT_IMPLEMENTED in this version). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_bwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_POOLING_BACKWARD_H */
third_party/libxsmm/src/libxsmm_dnn_pooling_forward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_forward.h"
#include "libxsmm_main.h"
/* Forward declarations of the AVX512-specialized forward-pooling kernels
 * defined later in this translation unit: one per datatype pair (f32/bf16)
 * and channel-block width (16/32/64). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
/** AVX512 forward pooling, F32 in/out, channel block of 16.
 *  Selects MAX or AVG from handle->desc.pooling_type; the loop nest itself is
 *  textually included from a template file, steered by the FWD_MAX/FWD_AVG
 *  macros and the element_* typedefs below.  start_thread/tid identify the
 *  calling thread (work partitioning happens inside the template). */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type used by the MAX template (presumably argmax bookkeeping -- see template) */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  }
  else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  }
  else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/** AVX512 forward pooling, F32 in/out, channel block of 32.
 *  Same structure as the c16 variant; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type used by the MAX template */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  }
  else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  }
  else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/** AVX512 forward pooling, F32 in/out, channel block of 64.
 *  Same structure as the c16 variant; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef float element_input_type;
  typedef float element_output_type;
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type used by the MAX template */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  }
  else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  }
  else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/** AVX512 forward pooling, BF16 in/out, channel block of 16.
 *  Shares the f32/bf16 template with the F32 variants; the FWD_BF16 macro
 *  switches the template onto its bfloat16 code path. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_FWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type used by the MAX template */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  }
  else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  }
  else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/** AVX512 forward pooling, BF16 in/out, channel block of 32.
 *  Same structure as the bf16 c16 variant; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_FWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type used by the MAX template */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  }
  else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  }
  else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/** AVX512 forward pooling, BF16 in/out, channel block of 64.
 *  Same structure as the bf16 c16 variant; only the included template differs. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_FWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type used by the MAX template */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  }
  else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  }
  else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/** Thread entry point for forward pooling in the custom (blocked) format.
 *  Verifies the required tensors are bound, then routes to an AVX512
 *  specialization chosen by datatype (f32/bf16) and channel-block width
 *  (ofmblock 16/32/64), or falls back to the generic template code.
 *  @param handle        pooling handle with bound tensors and descriptor
 *  @param start_thread  first thread id of the calling team
 *  @param tid           this thread's id
 *  @return LIBXSMM_DNN_SUCCESS, or an error code (data not bound,
 *          unsupported datatype, unsupported pooling type) */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and mask (the mask is only required for max-pooling) */
  if ( handle->reg_input == 0 || handle->reg_output == 0 ||
       ( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (handle->ofmblock == 16) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* NOTE(review): this assert requires a bound mask even for AVG pooling,
       * although the guard above allows mask==0 in that case -- confirm intent. */
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(handle, start_thread, tid);
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (handle->ofmblock == 32) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(handle, start_thread, tid);
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (handle->ofmblock == 64) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(handle, start_thread, tid);
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      LIBXSMM_ASSERT(NULL != handle->mask);
      status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* generic fallback: the loop nest is textually included from a template,
     * steered by the element_* typedefs and FWD_* macros below */
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
        typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
      }
      else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
      }
      else {
        status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
      }
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
# define LIBXSMM_DNN_POOLING_FWD_BF16
      if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
        typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
      }
      else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
      }
      else {
        status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
      }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/** Multithreaded forward pooling for the NHWC buffer format.
 *  Not implemented: the parameters are only touched to silence
 *  unused-parameter warnings.
 *  @return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED always */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  return LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
}
third_party/libxsmm/src/libxsmm_dnn_pooling_forward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_POOLING_FORWARD_H
#define LIBXSMM_DNN_POOLING_FORWARD_H
#include <libxsmm_dnn_pooling.h>
/** Thread entry point for forward pooling on the custom (blocked) tensor
 *  format; start_thread/tid identify the calling thread within the team. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid);
/** Thread entry point for forward pooling on the NHWC format
 *  (returns LIBXSMM_DNN_ERR_NOT_IMPLEMENTED in this version). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_POOLING_FORWARD_H */
third_party/libxsmm/src/libxsmm_dnn_rnncell.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas, Kunal Banerjee (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_rnncell_forward.h"
#include "libxsmm_dnn_rnncell_backward_weight_update.h"
#include "libxsmm_dnn_elementwise.h"
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/** Allocates and initializes an RNN-cell handle from a descriptor.
 *  Validates datatypes (F32 or BF16, in==out), BF16 architecture/shape
 *  constraints, and max_T >= 1; picks blocking factors; for BF16 it
 *  pre-dispatches all batch-reduce (strided) GEMM kernels for forward and
 *  backward/weight-update; finally creates the thread barrier.
 *  @param rnncell_desc  cell descriptor (copied into the handle)
 *  @param status        out: LIBXSMM_DNN_SUCCESS, a blocking warning, or an error
 *  @return new handle, or NULL on error (with *status set) */
LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_rnncell* handle = 0;
  /* init libxsmm */
  LIBXSMM_INIT
  /* some check we can do before allocating the handle */
  if ( (rnncell_desc.datatype_in != rnncell_desc.datatype_out) ||
       ( (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) && (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) ) ) {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    return NULL;
  }
  /* let's do some simple checks for BF16 as this limits the cell and architecture:
   * BF16 needs AVX512-CORE or newer and C, K divisible by 16 */
  if ( (rnncell_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (rnncell_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
    if ( (LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid) || (rnncell_desc.C % 16 != 0) || (rnncell_desc.K % 16 != 0) ) {
      *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return NULL;
    }
  }
  /* we need at least one timestep */
  if ( rnncell_desc.max_T < 1 ) {
    *status = LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL;
    return NULL;
  }
  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  handle = (libxsmm_dnn_rnncell*)calloc(1, sizeof(libxsmm_dnn_rnncell));
  if ( NULL != handle ) {
    *status = LIBXSMM_DNN_SUCCESS;
    /* initialize known handle components */
    handle->desc = rnncell_desc;
    /* set current seq length to max length */
    handle->T = rnncell_desc.max_T;
    /* set blocking factors (0 in the descriptor means "use the default of 64") */
    handle->bk = (handle->desc.bk == 0) ? 64 : handle->desc.bk;
    handle->bn = (handle->desc.bn == 0) ? 64 : handle->desc.bn;
    handle->bc = (handle->desc.bc == 0) ? 64 : handle->desc.bc;
    handle->use_fwd_fused_impl = handle->desc.use_fwd_fused_impl;
    handle->fwd_block = handle->desc.fwd_block;
    handle->bwdupd_block = handle->desc.bwdupd_block;
    /* lpb: low-precision blocking -- 2 elements per VNNI pair for BF16, 1 for F32 */
    if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
      handle->lpb = 2;
    } else {
      handle->lpb = 1;
    }
    /* validate blocking factors: fall back to a single block per dimension
     * (and report a warning status) when the factor does not divide the size */
    if ( handle->desc.N % handle->bn != 0 ) {
      handle->bn = handle->desc.N;
      *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING;
    }
    if ( handle->desc.C % handle->bc != 0 ) {
      handle->bc = handle->desc.C;
      *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING;
    }
    if ( handle->desc.K % handle->bk != 0 ) {
      handle->bk = handle->desc.K;
      *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING;
    }
    /* If in SPR, generate tilerelease kernel */
    if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
      int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
      handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->bk, handle->bk, handle->bk, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL);
    }
    /* In case of BF16 for now hoist the BRGEMM and make them to use STRIDED variant by default */
    if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
      libxsmm_blasint BF, CB_BLOCKS, KB_BLOCKS;
      const libxsmm_blasint K = handle->desc.K;
      const libxsmm_blasint N = handle->desc.N;
      const libxsmm_blasint C = handle->desc.C;
      const libxsmm_blasint bk = handle->bk;
      const libxsmm_blasint bn = handle->bn;
      const libxsmm_blasint bc = handle->bc;
      const libxsmm_blasint cBlocks = C / bc;
      const libxsmm_blasint kBlocks = K / bk;
      const libxsmm_blasint nBlocks = N / bn;
      int tc_flags = 0;
      int kernel_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
      int stride_a, stride_b;
      /* NOTE(review): the first condition uses '==' SPR while the tilerelease
       * check above uses '>=' SPR -- confirm whether archs beyond SPR should
       * also take this path. */
      if ( (libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
        /* skip per-kernel tileconfig only when all blockings are multiples of 32 */
        kernel_flags = ((handle->bk % 32 == 0) && (handle->bc % 32 == 0) && (handle->bn % 32 == 0)) ? LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG : 0;
        kernel_flags = kernel_flags | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
        tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
      }
      /* Blocking reduction domain if it is too large */
      BF = 1;
      if ( (C > 1024 && C <= 2048) || (K > 1024 && K <= 2048) ) {
        BF = 8;
        while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
          BF--;
        }
      }
      if ( C > 2048 || K > 2048 ) {
        BF = 16;
        while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
          BF--;
        }
      }
      if ( C == 2048 && K == 1024 ) {
        BF = 2;
      }
      /* NOTE(review): the heuristic BF computed above is unconditionally
       * overwritten by the descriptor-provided fwd_block here -- confirm the
       * heuristic is intentionally disabled. */
      BF = handle->fwd_block;
      if ( handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED ) {
        /* NCPACKED layout: leading dimensions are the block sizes themselves */
        CB_BLOCKS = cBlocks / BF;
        KB_BLOCKS = kBlocks / BF;
        /* define batch-reduce gemm kernels */
        stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bc * bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &bc, &bk, NULL, NULL, &kernel_flags, NULL);
        stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bk * bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &bk, &bk, NULL, NULL, &kernel_flags, NULL);
        if ( (libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
          handle->fwd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr(bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);
        }
        /* backward/weight-update kernels use their own blocking factor */
        BF = handle->bwdupd_block;
        KB_BLOCKS = kBlocks / BF;
        stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bk * bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &bk, &bc, NULL, NULL, &kernel_flags, NULL);
        stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &bn, &bk, NULL, NULL, &kernel_flags, NULL);
        stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bn * bc * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &bn, &bk, NULL, NULL, &kernel_flags, NULL);
        stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &bk, &bk, NULL, NULL, &kernel_flags, NULL);
        if ( (libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
          handle->bwdupd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr(bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);
        }
      } else {
        /* flat (NC) layout: leading dimensions are the full C/K/N extents */
        CB_BLOCKS = cBlocks / BF;
        KB_BLOCKS = kBlocks / BF;
        /* define batch-reduce gemm kernels */
        stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bc * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL);
        stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);
        if ( (libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
          handle->fwd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr(bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);
        }
        /* backward/weight-update kernels use their own blocking factor */
        BF = handle->bwdupd_block;
        KB_BLOCKS = kBlocks / BF;
        stride_a = bc * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);
        stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
        stride_a = bn * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bn * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
        stride_a = bk * bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        stride_b = bk * libxsmm_dnn_typesize(handle->desc.datatype_in);
        handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll(bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);
        if ( (libxsmm_target_archid == LIBXSMM_X86_AVX512_SPR) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
          handle->bwdupd_tileconfig = libxsmm_bsmmdispatch_reducebatch_addr(bk, bn, bk, &bk, &K, &K, NULL, NULL, &tc_flags, NULL);
        }
      }
    }
    /* Need to allocate space for scratch libxsmm_dnn_tensor's, let's set all pointers to zero */
    handle->internal_z = 0;
    handle->scratch_wT = 0;
    handle->scratch_rT = 0;
    handle->scratch_xT = 0;
    handle->scratch_hT = 0;
    handle->scratch_deltat = 0;
    handle->scratch_di = 0;
    handle->scratch_df = 0;
    handle->scratch_do = 0;
    handle->scratch_dci = 0;
    handle->scratch_diB = 0;
    handle->scratch_dfB = 0;
    handle->scratch_dpB = 0;
    handle->scratch_dciB = 0;
    /* initialize a high-performant barrier */
    handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
    if ( NULL == handle->barrier ) {
      *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
      free(handle);
      return NULL;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
  }
  return handle;
}
/** Releases an RNN-cell handle created by libxsmm_dnn_create_rnncell:
 *  tears down the thread barrier (if any) and frees the handle itself.
 *  @param handle  handle to destroy; may be NULL
 *  @return LIBXSMM_DNN_SUCCESS, or LIBXSMM_DNN_ERR_INVALID_HANDLE for NULL */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle)
{
  if ( 0 == handle ) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  /* release the barrier before the handle that owns it */
  if ( 0 != handle->barrier ) {
    libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
  }
  /* the cast removes constness for deallocation */
  free((libxsmm_dnn_rnncell*)handle);
  return LIBXSMM_DNN_SUCCESS;
}
LIBXSMM_API
libxsmm_dnn_tensor_datalayout
*
libxsmm_dnn_rnncell_create_tensor_datalayout
(
const
libxsmm_dnn_rnncell
*
handle
,
const
libxsmm_dnn_tensor_type
type
,
libxsmm_dnn_err_t
*
status
)
{
libxsmm_dnn_tensor_datalayout
*
layout
;
*
status
=
LIBXSMM_DNN_SUCCESS
;
layout
=
0
;
if
(
handle
!=
0
)
{
/* zero entire content; not only safer but also sets data and code pointers to NULL */
layout
=
(
libxsmm_dnn_tensor_datalayout
*
)
calloc
(
1
,
sizeof
(
libxsmm_dnn_tensor_datalayout
));
if
(
layout
!=
0
)
{
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_CS_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_CS_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_CS
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_CS
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_I
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_F
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_O
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_CI
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_CO
)
)
{
layout
->
format
=
handle
->
desc
.
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_ACTIVATION
;
if
((
handle
->
desc
.
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED
)
>
0
)
{
if
(
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
))
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_INPUT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_T
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bn
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
N
/
handle
->
bn
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)
handle
->
desc
.
max_T
;
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_CS_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_CS_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_CS
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_CS
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_I
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_F
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_O
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_CI
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_CO
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_T
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bn
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
N
/
handle
->
bn
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)
handle
->
desc
.
max_T
;
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
if
((
handle
->
desc
.
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_NC
)
>
0
)
{
if
(
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
))
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
3
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
3
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
3
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_INPUT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_T
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
N
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)
handle
->
desc
.
max_T
;
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_CS_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_CS_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_CS
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_CS
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_I
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_F
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_O
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_CI
)
||
(
type
==
LIBXSMM_DNN_RNN_INTERNAL_CO
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_T
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
N
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)
handle
->
desc
.
max_T
;
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT
)
)
{
layout
->
format
=
handle
->
desc
.
filter_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_FILTER
;
if
((
handle
->
desc
.
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED
)
>
0
)
{
if
(
(
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
||
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
4
]
=
4
;
}
else
{
layout
->
dim_size
[
4
]
=
3
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
4
]
=
4
;
}
else
{
layout
->
dim_size
[
4
]
=
3
;
}
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
4
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
4
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
4
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
if
(
(
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
||
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
6
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
6
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
6
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bc
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
5
]
=
4
;
}
else
{
layout
->
dim_size
[
5
]
=
3
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bk
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
5
]
=
4
;
}
else
{
layout
->
dim_size
[
5
]
=
3
;
}
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bc
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bk
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
if
((
handle
->
desc
.
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_CK
)
>
0
)
{
if
(
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
))
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
2
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
2
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
4
);
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
}
else
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
3
);
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
}
else
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
4
);
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
}
else
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
3
);
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
}
else
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
}
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS
)
||
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS
)
)
{
layout
->
format
=
handle
->
desc
.
filter_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_FILTER
;
if
((
handle
->
desc
.
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED
)
>
0
)
{
if
(
(
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
||
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
4
]
=
4
;
}
else
{
layout
->
dim_size
[
4
]
=
3
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
4
]
=
4
;
}
else
{
layout
->
dim_size
[
4
]
=
3
;
}
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
4
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
4
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
4
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
if
(
(
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
||
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
6
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
6
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
6
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bk
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
5
]
=
4
;
}
else
{
layout
->
dim_size
[
5
]
=
3
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bk
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
5
]
=
4
;
}
else
{
layout
->
dim_size
[
5
]
=
3
;
}
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bc
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bk
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
C
/
handle
->
bc
);
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
lpb
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
bk
;
layout
->
dim_size
[
2
]
=
(
unsigned
int
)(
handle
->
bk
/
handle
->
lpb
);
layout
->
dim_size
[
3
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
layout
->
dim_size
[
4
]
=
(
unsigned
int
)(
handle
->
desc
.
K
/
handle
->
bk
);
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
if
((
handle
->
desc
.
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_CK
)
>
0
)
{
if
(
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
))
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
2
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
2
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
4
);
}
else
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
3
);
}
else
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
C
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
4
);
}
else
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
3
);
}
else
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
layout
->
dim_size
[
1
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
}
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_BIAS
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_BIAS
)
)
{
layout
->
format
=
handle
->
desc
.
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_CHANNEL_SCALAR
;
if
(
((
handle
->
desc
.
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_NC
)
>
0
)
||
((
handle
->
desc
.
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED
)
>
0
)
)
{
if
(
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
((
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
))
)
{
layout
->
datatype
=
handle
->
desc
.
datatype_in
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
1
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
1
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
1
;
if
(
(
type
==
LIBXSMM_DNN_RNN_REGULAR_BIAS
)
||
(
type
==
LIBXSMM_DNN_RNN_GRADIENT_BIAS
)
)
{
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_LSTM
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
4
);
}
else
if
(
handle
->
desc
.
cell_type
==
LIBXSMM_DNN_RNNCELL_GRU
)
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)(
handle
->
desc
.
K
*
3
);
}
else
{
layout
->
dim_size
[
0
]
=
(
unsigned
int
)
handle
->
desc
.
K
;
}
}
else
{
/* coverity[dead_error_begin] */
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
*
status
=
LIBXSMM_DNN_ERR_CREATE_LAYOUT
;
}
}
else
{
*
status
=
LIBXSMM_DNN_ERR_INVALID_HANDLE
;
}
return
layout
;
}
LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
{
  /* Computes the scratch-buffer size (bytes) required by the given RNN cell for
   * the requested compute kind. Every individual buffer is padded by 64 bytes
   * so it can later be aligned to a 64-byte boundary by bind_scratch.
   * Returns 0 and sets *status on error (invalid handle/kind/cell type). */
  size_t size = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    /* element sizes: inputs use datatype_in; dw/dr terms accumulate in FP32 when the output type is BF16 */
    const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
    const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
    const size_t typesize_out = libxsmm_dnn_typesize(handle->desc.datatype_out);
    const size_t nC = (size_t)handle->desc.C;
    const size_t nK = (size_t)handle->desc.K;
    const size_t nN = (size_t)handle->desc.N;
    const size_t nT = (size_t)handle->desc.max_T;

    switch (handle->desc.cell_type) {
      case LIBXSMM_DNN_RNNCELL_RNN_RELU:
      case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
      case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            size += 0; /* forward pass of the vanilla RNN needs no scratch */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += nC * nK * typesize_in + 64;       /* wT */
            size += nK * nK * typesize_in + 64;       /* rT */
            size += nC * nN * typesize_in + 64;       /* xT */
            size += nK * nN * typesize_out + 64;      /* hT */
            size += nK * nN * typesize_out * nT + 64; /* deltat */
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_LSTM: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            size += nC * nK * typesize_in * 4 + 4 * 64; /* w */
            size += nK * nK * typesize_in * 4 + 4 * 64; /* r */
            /* The scratches below are needed only for BF16 code for the intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              size += (size_t)7 * (nK * nN * sizeof(float) * nT + 64); /* intermediate scratches */
              size += nK * nN * sizeof(float) + 64;                    /* intermediate scratches */
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += nC * nK * dwdr_typesize * 4 + 4 * 64; /* w */
            size += nK * nK * dwdr_typesize * 4 + 4 * 64; /* r */
            size += nC * nK * typesize_in * 4 + 4 * 64;   /* wT */
            size += nK * nK * typesize_in * 4 + 4 * 64;   /* rT */
            size += nC * nN * typesize_in + 64;           /* xT */
            size += nK * nN * typesize_out + 64;          /* hT */
            size += nK * nN * dwdr_typesize + 64;         /* deltat */
            /* di, df, do, dci, diB, dfB, dpB, dciB, t1, t2: ten equally-sized buffers */
            size += (size_t)10 * (nK * nN * typesize_out + 64);
            /* The scratches below are needed only for BF16 code for the intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              size += (size_t)4 * (nK * sizeof(float) + 64);            /* intermediate db scratch */
              size += nC * nN * sizeof(float) * nT + 64;                /* intermediate dx scratches */
              size += (size_t)7 * (nK * nN * sizeof(float) * nT + 64);  /* intermediate scratches */
              size += (size_t)2 * (nK * nN * sizeof(float) + 64);       /* intermediate scratches */
            }
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_GRU: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            size += nC * nK * typesize_in * 3 + 3 * 64; /* w */
            size += nK * nK * typesize_in * 3 + 3 * 64; /* r */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += nC * nK * dwdr_typesize * 3 + 3 * 64; /* w */
            size += nK * nK * dwdr_typesize * 3 + 3 * 64; /* r */
            size += nC * nK * typesize_in * 3 + 3 * 64;   /* wT */
            size += nK * nK * typesize_in * 3 + 3 * 64;   /* rT */
            size += nC * nN * typesize_in + 64;           /* xT */
            size += nK * nN * typesize_out + 64;          /* hT */
            size += nK * nN * dwdr_typesize + 64;         /* deltat */
            /* di, dc, df, do, diB, dcB, dfB, oT, t1, t2: ten equally-sized buffers */
            size += (size_t)10 * (nK * nN * typesize_out + 64);
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
      }
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return size;
}
LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
{
  /* Returns the scratch base pointer previously bound to the handle,
   * or NULL (with *status set to LIBXSMM_DNN_ERR_INVALID_HANDLE) when
   * the handle itself is invalid. */
  if (NULL == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return NULL;
  }
  *status = LIBXSMM_DNN_SUCCESS;
  return handle->scratch_base;
}
/* Rounds an address up to the next 64-byte boundary (identity when already aligned). */
static void* internal_rnncell_align64(uintptr_t address)
{
  const uintptr_t remainder = address % 64;
  return (void*)((0 == remainder) ? address : (address + (64 - remainder)));
}

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch)
{
  /* Carves the user-provided scratch buffer into the per-buffer pointers the
   * RNN/LSTM/GRU kernels expect. Each sub-buffer pointer is 64-byte aligned;
   * the running cursor is advanced by the buffer size plus 64 bytes of
   * alignment slack, mirroring libxsmm_dnn_rnncell_get_scratch_size.
   * Returns LIBXSMM_DNN_SUCCESS or an error status (NULL handle, NULL
   * scratch where one is required, invalid kind/cell type). */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (NULL != handle) {
    /* element sizes: inputs use datatype_in; dw/dr terms accumulate in FP32 when the output type is BF16 */
    const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
    const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
    const size_t typesize_out = libxsmm_dnn_typesize(handle->desc.datatype_out);
    const size_t nC = (size_t)handle->desc.C;
    const size_t nK = (size_t)handle->desc.K;
    const size_t nN = (size_t)handle->desc.N;
    const size_t nT = (size_t)handle->desc.max_T;
    uintptr_t address = (uintptr_t)scratch;

    switch (handle->desc.cell_type) {
      case LIBXSMM_DNN_RNNCELL_RNN_RELU:
      case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
      case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            /* forward only has no scratch need */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* wT */
            handle->scratch_wT = internal_rnncell_align64(address);
            address += (nC * nK * typesize_in) + 64;
            /* rT */
            handle->scratch_rT = internal_rnncell_align64(address);
            address += (nK * nK * typesize_in) + 64;
            /* xT */
            handle->scratch_xT = internal_rnncell_align64(address);
            address += (nC * nN * typesize_in) + 64;
            /* hT */
            handle->scratch_hT = internal_rnncell_align64(address);
            address += (nK * nN * typesize_out) + 64;
            /* deltat */
            handle->scratch_deltat = internal_rnncell_align64(address);
            address += (nK * nN * typesize_out * nT) + 64;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_LSTM: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w scratch */
            handle->scratch_w = internal_rnncell_align64(address);
            address += (nC * nK * typesize_in) * 4 + 64;
            /* r scratch */
            handle->scratch_r = internal_rnncell_align64(address);
            address += (nK * nK * typesize_in) * 4 + 64;
            /* The scratches below are needed only for BF16 code for the intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              /* cst scratch */
              handle->cst_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* ht scratch */
              handle->ht_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* it scratch */
              handle->it_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* ft scratch */
              handle->ft_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* ot scratch */
              handle->ot_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* cit scratch */
              handle->cit_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* cot scratch */
              handle->cot_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* csp scratch */
              handle->csp_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) + 64;
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w scratch */
            handle->scratch_w = internal_rnncell_align64(address);
            address += (nC * nK * dwdr_typesize) * 4 + 64;
            /* r scratch */
            handle->scratch_r = internal_rnncell_align64(address);
            address += (nK * nK * dwdr_typesize) * 4 + 64;
            /* wT */
            handle->scratch_wT = internal_rnncell_align64(address);
            address += (nC * nK * typesize_in) * 4 + 64;
            /* rT */
            handle->scratch_rT = internal_rnncell_align64(address);
            address += (nK * nK * typesize_in) * 4 + 64;
            /* xT */
            handle->scratch_xT = internal_rnncell_align64(address);
            address += nC * nN * typesize_in + 64;
            /* hT */
            handle->scratch_hT = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* deltat */
            handle->scratch_deltat = internal_rnncell_align64(address);
            address += nK * nN * dwdr_typesize + 64;
            /* di */
            handle->scratch_di = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* df */
            handle->scratch_df = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* do */
            handle->scratch_do = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dci */
            handle->scratch_dci = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* diB */
            handle->scratch_diB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dfB */
            handle->scratch_dfB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dpB */
            handle->scratch_dpB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dciB */
            handle->scratch_dciB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* t1 */
            handle->scratch_t1 = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* t2 */
            handle->scratch_t2 = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* The scratches below are needed only for BF16 code for the intermediate results */
            if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
              /* dx scratch */
              handle->scratch_dx = internal_rnncell_align64(address);
              address += nC * nN * sizeof(float) * nT + 64;
              /* dhp scratch */
              handle->scratch_dhp = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) + 64;
              /* db scratch */
              handle->scratch_db = internal_rnncell_align64(address);
              address += nK * 4 * sizeof(float) + 64;
              /* cst scratch */
              handle->cst_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* ht scratch */
              handle->ht_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* it scratch */
              handle->it_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* ft scratch */
              handle->ft_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* ot scratch */
              handle->ot_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* cit scratch */
              handle->cit_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* cot scratch */
              handle->cot_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) * nT + 64;
              /* csp scratch */
              handle->csp_scratch = internal_rnncell_align64(address);
              address += nK * nN * sizeof(float) + 64;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_GRU: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w scratch */
            handle->scratch_w = internal_rnncell_align64(address);
            address += (nC * nK * typesize_in) * 3 + 64;
            /* r scratch */
            handle->scratch_r = internal_rnncell_align64(address);
            address += (nK * nK * typesize_in) * 3 + 64;
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            if (scratch == 0) {
              status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
              return status;
            }
            handle->scratch_base = (void*)address;
            /* w scratch */
            handle->scratch_w = internal_rnncell_align64(address);
            address += (nC * nK * dwdr_typesize) * 3 + 64;
            /* r scratch */
            handle->scratch_r = internal_rnncell_align64(address);
            address += (nK * nK * dwdr_typesize) * 3 + 64;
            /* wT */
            handle->scratch_wT = internal_rnncell_align64(address);
            address += (nC * nK * typesize_in) * 3 + 64;
            /* rT */
            handle->scratch_rT = internal_rnncell_align64(address);
            address += (nK * nK * typesize_in) * 3 + 64;
            /* xT */
            handle->scratch_xT = internal_rnncell_align64(address);
            address += nC * nN * typesize_in + 64;
            /* hT */
            handle->scratch_hT = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* deltat */
            handle->scratch_deltat = internal_rnncell_align64(address);
            address += nK * nN * dwdr_typesize + 64;
            /* di */
            handle->scratch_di = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dc */
            handle->scratch_dci = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* df */
            handle->scratch_df = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* do */
            handle->scratch_do = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* diB */
            handle->scratch_diB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dcB */
            handle->scratch_dciB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* dfB */
            handle->scratch_dfB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* doB (repurposed for oT) */
            handle->scratch_dpB = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* t1 */
            handle->scratch_t1 = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
            /* t2 */
            handle->scratch_t2 = internal_rnncell_align64(address);
            address += nK * nN * typesize_out + 64;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
{
  /* Clears the scratch sub-buffer pointers that bind_scratch set for the
   * given compute kind; the caller still owns and frees the underlying
   * memory. Returns an error status for a NULL handle or an unsupported
   * kind/cell type. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    switch (handle->desc.cell_type) {
      case LIBXSMM_DNN_RNNCELL_RNN_RELU:
      case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
      case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            /* forward only has no scratch need */
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            handle->scratch_wT = NULL;
            handle->scratch_rT = NULL;
            handle->scratch_xT = NULL;
            handle->scratch_hT = NULL;
            handle->scratch_deltat = NULL;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_LSTM: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            handle->scratch_w = NULL;
            handle->scratch_r = NULL;
            handle->csp_scratch = NULL;
            handle->cst_scratch = NULL;
            handle->ht_scratch = NULL;
            handle->it_scratch = NULL;
            handle->ft_scratch = NULL;
            handle->ot_scratch = NULL;
            handle->cit_scratch = NULL;
            handle->cot_scratch = NULL;
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            /* weight/transpose/temporary buffers */
            handle->scratch_w = NULL;
            handle->scratch_r = NULL;
            handle->scratch_wT = NULL;
            handle->scratch_rT = NULL;
            handle->scratch_xT = NULL;
            handle->scratch_hT = NULL;
            handle->scratch_deltat = NULL;
            /* gate-gradient buffers */
            handle->scratch_di = NULL;
            handle->scratch_df = NULL;
            handle->scratch_do = NULL;
            handle->scratch_dci = NULL;
            handle->scratch_diB = NULL;
            handle->scratch_dfB = NULL;
            handle->scratch_dpB = NULL;
            handle->scratch_dciB = NULL;
            handle->scratch_t1 = NULL;
            handle->scratch_t2 = NULL;
            /* BF16 intermediate buffers */
            handle->csp_scratch = NULL;
            handle->cst_scratch = NULL;
            handle->ht_scratch = NULL;
            handle->it_scratch = NULL;
            handle->ft_scratch = NULL;
            handle->ot_scratch = NULL;
            handle->cit_scratch = NULL;
            handle->cot_scratch = NULL;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_GRU: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            handle->scratch_w = NULL;
            handle->scratch_r = NULL;
            handle->ht_scratch = NULL;
            handle->it_scratch = NULL;
            handle->cit_scratch = NULL;
            handle->ft_scratch = NULL;
            handle->ot_scratch = NULL;
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            /* weight/transpose/temporary buffers */
            handle->scratch_w = NULL;
            handle->scratch_r = NULL;
            handle->scratch_wT = NULL;
            handle->scratch_rT = NULL;
            handle->scratch_xT = NULL;
            handle->scratch_hT = NULL;
            handle->scratch_deltat = NULL;
            /* gate-gradient buffers */
            handle->scratch_di = NULL;
            handle->scratch_dci = NULL;
            handle->scratch_df = NULL;
            handle->scratch_do = NULL;
            handle->scratch_diB = NULL;
            handle->scratch_dciB = NULL;
            handle->scratch_dfB = NULL;
            handle->scratch_dpB = NULL;
            handle->scratch_t1 = NULL;
            handle->scratch_t2 = NULL;
            /* intermediate buffers */
            handle->ht_scratch = NULL;
            handle->it_scratch = NULL;
            handle->ft_scratch = NULL;
            handle->ot_scratch = NULL;
            handle->cit_scratch = NULL;
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status) {
  /* Return the byte size of internal state needed for the given compute kind.
   * Only the vanilla RNN cells keep internal state (the pre-activation buffer zt);
   * LSTM/GRU expose their gate activations as i/o tensors instead. */
  size_t size = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return size;
  }

  {
    /* size is computed in fp32 elements (sizeof(float)) */
    const size_t sizeof_datatype = sizeof(float);

    switch (handle->desc.cell_type) {
      case LIBXSMM_DNN_RNNCELL_RNN_RELU:
      case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
      case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
        switch (kind) {
          /* forward and backward/update need the same zt buffer:
           * K x N per time step over max_T steps, plus 64 bytes of alignment slack */
          case LIBXSMM_DNN_COMPUTE_KIND_FWD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_LSTM: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      case LIBXSMM_DNN_RNNCELL_GRU: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWD:
          case LIBXSMM_DNN_COMPUTE_KIND_UPD:
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
          case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
            /* with i, f, c, o exposed as i/o, there is currently no need for internal state */
          } break;
          default: {
            *status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
      }
    }
  }

  return size;
}
LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status) {
  /* Return the currently bound internal-state pointer (internal_z), or NULL
   * with *status set to LIBXSMM_DNN_ERR_INVALID_HANDLE when handle is NULL. */
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return NULL;
  }
  return handle->internal_z;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate) {
  /* Bind caller-provided internal-state memory to the handle.
   * Only the vanilla RNN cells use internal state (internal_z); for LSTM/GRU
   * this call only validates the compute kind. The bound pointer is rounded
   * up to the next 64-byte boundary. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  const uintptr_t address = (uintptr_t)internalstate;

  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  switch (handle->desc.cell_type) {
    case LIBXSMM_DNN_RNNCELL_RNN_RELU:
    case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
    case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
      if (0 == internalstate) {
        return LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
      }
      switch (kind) {
        /* forward and backward/update bind the same buffer */
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          if (0 == (address % 64)) {
            handle->internal_z = (void*)address;
          } else {
            /* advance to the next 64-byte aligned address */
            handle->internal_z = (void*)(address + (64 - address % 64));
          }
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_LSTM:
    case LIBXSMM_DNN_RNNCELL_GRU: {
      /* nothing to bind for LSTM/GRU, but the compute kind is still validated */
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    default: {
      status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
    }
  }

  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind) {
  /* Drop the internal-state binding (internal_z) for the given compute kind.
   * The memory is owned by the caller; only the pointer is cleared.
   * LSTM/GRU keep no internal state, so only the kind is validated there. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  switch (handle->desc.cell_type) {
    case LIBXSMM_DNN_RNNCELL_RNN_RELU:
    case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
    case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          handle->internal_z = 0;
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    case LIBXSMM_DNN_RNNCELL_LSTM:
    case LIBXSMM_DNN_RNNCELL_GRU: {
      switch (kind) {
        case LIBXSMM_DNN_COMPUTE_KIND_FWD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWD:
        case LIBXSMM_DNN_COMPUTE_KIND_UPD:
        case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
        case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
          /* nothing bound for LSTM/GRU */
        } break;
        default: {
          status = LIBXSMM_DNN_ERR_INVALID_KIND;
        }
      }
    } break;
    default: {
      status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
    }
  }

  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias) {
  /* Record the forget-gate bias value on the handle. Despite the name, no
   * allocation happens here -- the value is simply stored. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    handle->forget_bias = forget_bias;
  } else {
    /* NOTE(review): other handle-only entry points report LIBXSMM_DNN_ERR_INVALID_HANDLE
     * here; this one uses the _TENSOR variant -- kept as-is to preserve behavior. */
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  /* Attach a user tensor to the handle slot identified by `type`.
   * The tensor's datalayout must match the layout the cell expects for that slot;
   * otherwise LIBXSMM_DNN_ERR_MISMATCH_TENSOR is returned and nothing is bound. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* reject tensor types the RNN cell does not know about */
  switch (type) {
    case LIBXSMM_DNN_RNN_REGULAR_INPUT:
    case LIBXSMM_DNN_RNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:
    case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV:
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:
    case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:
    case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS:
    case LIBXSMM_DNN_RNN_REGULAR_BIAS:
    case LIBXSMM_DNN_RNN_GRADIENT_BIAS:
    case LIBXSMM_DNN_RNN_REGULAR_CS:
    case LIBXSMM_DNN_RNN_GRADIENT_CS:
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:
    case LIBXSMM_DNN_RNN_INTERNAL_I:
    case LIBXSMM_DNN_RNN_INTERNAL_F:
    case LIBXSMM_DNN_RNN_INTERNAL_O:
    case LIBXSMM_DNN_RNN_INTERNAL_CI:
    case LIBXSMM_DNN_RNN_INTERNAL_CO:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }

  if (handle == 0 || tensor == 0) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  {
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_rnncell_create_tensor_datalayout(handle, type, &status);

    if (0 == libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status)) {
      libxsmm_dnn_tensor* const t = (libxsmm_dnn_tensor*)tensor;
      switch (type) {
        case LIBXSMM_DNN_RNN_REGULAR_INPUT:              handle->xt   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_INPUT:             handle->dxt  = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:            handle->csp  = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:           handle->dcsp = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:  handle->hp   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV: handle->dhp  = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:             handle->w    = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:       handle->wt   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:            handle->dw   = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:       handle->r    = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS: handle->rt   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:      handle->dr   = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_BIAS:               handle->b    = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_BIAS:              handle->db   = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_CS:                 handle->cst  = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_CS:                handle->dcs  = t; break;
        case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:       handle->ht   = t; break;
        case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:      handle->dht  = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_I:                 handle->it   = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_F:                 handle->ft   = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_O:                 handle->ot   = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_CI:                handle->cit  = t; break;
        case LIBXSMM_DNN_RNN_INTERNAL_CO:                handle->cot  = t; break;
        default: /* cannot happen: type was validated above */ break;
      }
    } else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }

    libxsmm_dnn_destroy_tensor_datalayout(handle_layout);
  }

  return status;
}
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  /* Return the tensor currently bound to the slot identified by `type`,
   * or NULL for an unknown type, a NULL handle, or an unbound slot.
   * Note: *status is currently never written (TODO upstream). */
  libxsmm_dnn_tensor* tensor = 0;
  LIBXSMM_UNUSED(status/*TODO*/);

  if (0 == handle) {
    return tensor;
  }

  switch (type) {
    case LIBXSMM_DNN_RNN_REGULAR_INPUT:              tensor = handle->xt;   break;
    case LIBXSMM_DNN_RNN_GRADIENT_INPUT:             tensor = handle->dxt;  break;
    case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:            tensor = handle->csp;  break;
    case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:           tensor = handle->dcsp; break;
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:  tensor = handle->hp;   break;
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV: tensor = handle->dhp;  break;
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:             tensor = handle->w;    break;
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:       tensor = handle->wt;   break;
    case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:            tensor = handle->dw;   break;
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:       tensor = handle->r;    break;
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS: tensor = handle->rt;   break;
    case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:      tensor = handle->dr;   break;
    case LIBXSMM_DNN_RNN_REGULAR_BIAS:               tensor = handle->b;    break;
    case LIBXSMM_DNN_RNN_GRADIENT_BIAS:              tensor = handle->db;   break;
    case LIBXSMM_DNN_RNN_REGULAR_CS:                 tensor = handle->cst;  break;
    case LIBXSMM_DNN_RNN_GRADIENT_CS:                tensor = handle->dcs;  break;
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:       tensor = handle->ht;   break;
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:      tensor = handle->dht;  break;
    case LIBXSMM_DNN_RNN_INTERNAL_I:                 tensor = handle->it;   break;
    case LIBXSMM_DNN_RNN_INTERNAL_F:                 tensor = handle->ft;   break;
    case LIBXSMM_DNN_RNN_INTERNAL_O:                 tensor = handle->ot;   break;
    case LIBXSMM_DNN_RNN_INTERNAL_CI:                tensor = handle->cit;  break;
    case LIBXSMM_DNN_RNN_INTERNAL_CO:                tensor = handle->cot;  break;
    default: /* unknown tensor type: fall through and return NULL */ break;
  }

  return tensor;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type) {
  /* Unbind (NULL-out) the tensor slot identified by `type`.
   * Tensor memory is owned by the caller; only the binding is cleared. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* reject tensor types the RNN cell does not know about */
  switch (type) {
    case LIBXSMM_DNN_RNN_REGULAR_INPUT:
    case LIBXSMM_DNN_RNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:
    case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV:
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:
    case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:
    case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS:
    case LIBXSMM_DNN_RNN_REGULAR_BIAS:
    case LIBXSMM_DNN_RNN_GRADIENT_BIAS:
    case LIBXSMM_DNN_RNN_REGULAR_CS:
    case LIBXSMM_DNN_RNN_GRADIENT_CS:
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:
    case LIBXSMM_DNN_RNN_INTERNAL_I:
    case LIBXSMM_DNN_RNN_INTERNAL_F:
    case LIBXSMM_DNN_RNN_INTERNAL_O:
    case LIBXSMM_DNN_RNN_INTERNAL_CI:
    case LIBXSMM_DNN_RNN_INTERNAL_CO:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }

  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  switch (type) {
    case LIBXSMM_DNN_RNN_REGULAR_INPUT:              handle->xt   = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_INPUT:             handle->dxt  = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_CS_PREV:            handle->csp  = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_CS_PREV:           handle->dcsp = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV:  handle->hp   = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV: handle->dhp  = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT:             handle->w    = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS:       handle->wt   = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_WEIGHT:            handle->dw   = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT:       handle->r    = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS: handle->rt   = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT:      handle->dr   = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_BIAS:               handle->b    = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_BIAS:              handle->db   = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_CS:                 handle->cst  = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_CS:                handle->dcs  = 0; break;
    case LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE:       handle->ht   = 0; break;
    case LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE:      handle->dht  = 0; break;
    case LIBXSMM_DNN_RNN_INTERNAL_I:                 handle->it   = 0; break;
    case LIBXSMM_DNN_RNN_INTERNAL_F:                 handle->ft   = 0; break;
    case LIBXSMM_DNN_RNN_INTERNAL_O:                 handle->ot   = 0; break;
    case LIBXSMM_DNN_RNN_INTERNAL_CI:                handle->cit  = 0; break;
    case LIBXSMM_DNN_RNN_INTERNAL_CO:                handle->cot  = 0; break;
    default: /* cannot happen: type was validated above */ break;
  }

  return status;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length(libxsmm_dnn_rnncell* handle, const libxsmm_blasint T) {
  /* Set the active sequence length; it must not exceed the max_T the handle
   * was created with. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else if (T > handle->desc.max_T) {
    status = LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN;
  } else {
    handle->T = T;
  }
  return status;
}
LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length(libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status) {
  /* Return the currently active sequence length, or 0 with *status set to
   * LIBXSMM_DNN_ERR_INVALID_HANDLE when handle is NULL. */
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return 0;
  }
  return handle->T;
}
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, /*unsigned*/int start_thread, /*unsigned*/int tid) {
  /* Single-thread-entry execution: dispatch to the kernel matching the
   * handle's buffer/filter format combination for the requested pass.
   * Supported combinations: NC+CK, NC+CKPACKED, NCPACKED+CKPACKED. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  switch (kind) {
    case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
      if ((handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) &&
          (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK)) {
        status = libxsmm_dnn_rnncell_st_fwd_nc_ck(handle, start_thread, tid);
      } else if ((handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) &&
                 (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)) {
        status = libxsmm_dnn_rnncell_st_fwd_nc_kcck(handle, start_thread, tid);
      } else if ((handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) &&
                 (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)) {
        status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(handle, start_thread, tid);
      } else {
        status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
      }
    } break;
    case LIBXSMM_DNN_COMPUTE_KIND_BWD:
    case LIBXSMM_DNN_COMPUTE_KIND_UPD:
    case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
      /* the bwdupd kernels receive `kind` and perform BWD, UPD or both */
      if ((handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) &&
          (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK)) {
        status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck(handle, kind, start_thread, tid);
      } else if ((handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) &&
                 (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)) {
        status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(handle, kind, start_thread, tid);
      } else if ((handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) &&
                 (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)) {
        status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(handle, kind, start_thread, tid);
      } else {
        status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
      }
    } break;
    default: {
      /* NOTE(review): LIBXSMM_DNN_COMPUTE_KIND_ALL is not dispatched here and
       * yields INVALID_KIND -- matches the original behavior; confirm intended. */
      status = LIBXSMM_DNN_ERR_INVALID_KIND;
    }
  }

  return status;
}
third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Kunal Banerjee, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_rnncell_backward_weight_update.h"
#include "libxsmm_dnn_elementwise.h"
#include "libxsmm_main.h"
LIBXSMM_API_INLINE
LIBXSMM_INTRINSICS
(
LIBXSMM_X86_AVX512_CORE
)
void
trans_act
(
short
int
*
in
,
short
int
*
out
)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
__m512i
r0
,
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
,
r8
,
r9
,
ra
,
rb
,
rc
,
rd
,
re
,
rf
;
__m512i
t0
,
t1
,
t2
,
t3
,
t4
,
t5
,
t6
,
t7
,
t8
,
t9
,
ta
,
tb
,
tc
,
td
,
te
,
tf
;
__m512i
v0
,
v1
,
v2
,
v3
,
v4
,
v5
,
v6
,
v7
;
const
__m512i
idx_v
=
_mm512_set_epi64
(
13
,
12
,
7
,
6
,
9
,
8
,
3
,
2
);
const
__mmask8
mask0
=
LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8
(
204
);
const
__mmask8
mask1
=
LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8
(
51
);
const
int
in_width
=
32
,
out_width
=
32
;
r0
=
_mm512_loadu_si512
(
in
+
0
*
in_width
);
r1
=
_mm512_loadu_si512
(
in
+
1
*
in_width
);
t0
=
_mm512_unpacklo_epi16
(
r0
,
r1
);
t1
=
_mm512_unpackhi_epi16
(
r0
,
r1
);
r2
=
_mm512_loadu_si512
(
in
+
2
*
in_width
);
r3
=
_mm512_loadu_si512
(
in
+
3
*
in_width
);
t2
=
_mm512_unpacklo_epi16
(
r2
,
r3
);
t3
=
_mm512_unpackhi_epi16
(
r2
,
r3
);
r4
=
_mm512_loadu_si512
(
in
+
4
*
in_width
);
r5
=
_mm512_loadu_si512
(
in
+
5
*
in_width
);
t4
=
_mm512_unpacklo_epi16
(
r4
,
r5
);
t5
=
_mm512_unpackhi_epi16
(
r4
,
r5
);
r6
=
_mm512_loadu_si512
(
in
+
6
*
in_width
);
r7
=
_mm512_loadu_si512
(
in
+
7
*
in_width
);
t6
=
_mm512_unpacklo_epi16
(
r6
,
r7
);
t7
=
_mm512_unpackhi_epi16
(
r6
,
r7
);
r8
=
_mm512_loadu_si512
(
in
+
8
*
in_width
);
r9
=
_mm512_loadu_si512
(
in
+
9
*
in_width
);
t8
=
_mm512_unpacklo_epi16
(
r8
,
r9
);
t9
=
_mm512_unpackhi_epi16
(
r8
,
r9
);
ra
=
_mm512_loadu_si512
(
in
+
10
*
in_width
);
rb
=
_mm512_loadu_si512
(
in
+
11
*
in_width
);
ta
=
_mm512_unpacklo_epi16
(
ra
,
rb
);
tb
=
_mm512_unpackhi_epi16
(
ra
,
rb
);
rc
=
_mm512_loadu_si512
(
in
+
12
*
in_width
);
rd
=
_mm512_loadu_si512
(
in
+
13
*
in_width
);
tc
=
_mm512_unpacklo_epi16
(
rc
,
rd
);
td
=
_mm512_unpackhi_epi16
(
rc
,
rd
);
re
=
_mm512_loadu_si512
(
in
+
14
*
in_width
);
rf
=
_mm512_loadu_si512
(
in
+
15
*
in_width
);
te
=
_mm512_unpacklo_epi16
(
re
,
rf
);
tf
=
_mm512_unpackhi_epi16
(
re
,
rf
);
r0
=
_mm512_unpacklo_epi32
(
t0
,
t2
);
r1
=
_mm512_unpackhi_epi32
(
t0
,
t2
);
r2
=
_mm512_unpacklo_epi32
(
t1
,
t3
);
r3
=
_mm512_unpackhi_epi32
(
t1
,
t3
);
r4
=
_mm512_unpacklo_epi32
(
t4
,
t6
);
r5
=
_mm512_unpackhi_epi32
(
t4
,
t6
);
r6
=
_mm512_unpacklo_epi32
(
t5
,
t7
);
r7
=
_mm512_unpackhi_epi32
(
t5
,
t7
);
r8
=
_mm512_unpacklo_epi32
(
t8
,
ta
);
r9
=
_mm512_unpackhi_epi32
(
t8
,
ta
);
ra
=
_mm512_unpacklo_epi32
(
t9
,
tb
);
rb
=
_mm512_unpackhi_epi32
(
t9
,
tb
);
rc
=
_mm512_unpacklo_epi32
(
tc
,
te
);
rd
=
_mm512_unpackhi_epi32
(
tc
,
te
);
re
=
_mm512_unpacklo_epi32
(
td
,
tf
);
rf
=
_mm512_unpackhi_epi32
(
td
,
tf
);
t0
=
_mm512_unpacklo_epi64
(
r0
,
r4
);
t1
=
_mm512_unpackhi_epi64
(
r0
,
r4
);
t2
=
_mm512_unpacklo_epi64
(
r1
,
r5
);
t3
=
_mm512_unpackhi_epi64
(
r1
,
r5
);
t4
=
_mm512_unpacklo_epi64
(
r2
,
r6
);
t5
=
_mm512_unpackhi_epi64
(
r2
,
r6
);
t6
=
_mm512_unpacklo_epi64
(
r3
,
r7
);
t7
=
_mm512_unpackhi_epi64
(
r3
,
r7
);
t8
=
_mm512_unpacklo_epi64
(
r8
,
rc
);
t9
=
_mm512_unpackhi_epi64
(
r8
,
rc
);
ta
=
_mm512_unpacklo_epi64
(
r9
,
rd
);
tb
=
_mm512_unpackhi_epi64
(
r9
,
rd
);
tc
=
_mm512_unpacklo_epi64
(
ra
,
re
);
td
=
_mm512_unpackhi_epi64
(
ra
,
re
);
te
=
_mm512_unpacklo_epi64
(
rb
,
rf
);
tf
=
_mm512_unpackhi_epi64
(
rb
,
rf
);
r0
=
_mm512_shuffle_i32x4
(
t0
,
t1
,
0x88
);
r1
=
_mm512_shuffle_i32x4
(
t2
,
t3
,
0x88
);
r2
=
_mm512_shuffle_i32x4
(
t4
,
t5
,
0x88
);
r3
=
_mm512_shuffle_i32x4
(
t6
,
t7
,
0x88
);
r4
=
_mm512_shuffle_i32x4
(
t0
,
t1
,
0xdd
);
r5
=
_mm512_shuffle_i32x4
(
t2
,
t3
,
0xdd
);
r6
=
_mm512_shuffle_i32x4
(
t4
,
t5
,
0xdd
);
r7
=
_mm512_shuffle_i32x4
(
t6
,
t7
,
0xdd
);
r8
=
_mm512_shuffle_i32x4
(
t8
,
t9
,
0x88
);
r9
=
_mm512_shuffle_i32x4
(
ta
,
tb
,
0x88
);
ra
=
_mm512_shuffle_i32x4
(
tc
,
td
,
0x88
);
rb
=
_mm512_shuffle_i32x4
(
te
,
tf
,
0x88
);
rc
=
_mm512_shuffle_i32x4
(
t8
,
t9
,
0xdd
);
rd
=
_mm512_shuffle_i32x4
(
ta
,
tb
,
0xdd
);
re
=
_mm512_shuffle_i32x4
(
tc
,
td
,
0xdd
);
rf
=
_mm512_shuffle_i32x4
(
te
,
tf
,
0xdd
);
v0
=
_mm512_permutex2var_epi64
(
r0
,
idx_v
,
r8
);
t0
=
_mm512_mask_blend_epi64
(
mask0
,
r0
,
v0
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
0
*
out_width
),
_mm512_extracti64x4_epi64
(
t0
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
1
*
out_width
),
_mm512_extracti64x4_epi64
(
t0
,
1
));
t8
=
_mm512_mask_blend_epi64
(
mask1
,
r8
,
v0
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
*
out_width
),
_mm512_extracti64x4_epi64
(
t8
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
17
*
out_width
),
_mm512_extracti64x4_epi64
(
t8
,
1
));
v1
=
_mm512_permutex2var_epi64
(
r1
,
idx_v
,
r9
);
t1
=
_mm512_mask_blend_epi64
(
mask0
,
r1
,
v1
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
2
*
out_width
),
_mm512_extracti64x4_epi64
(
t1
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
3
*
out_width
),
_mm512_extracti64x4_epi64
(
t1
,
1
));
t9
=
_mm512_mask_blend_epi64
(
mask1
,
r9
,
v1
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
18
*
out_width
),
_mm512_extracti64x4_epi64
(
t9
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
19
*
out_width
),
_mm512_extracti64x4_epi64
(
t9
,
1
));
v2
=
_mm512_permutex2var_epi64
(
r2
,
idx_v
,
ra
);
t2
=
_mm512_mask_blend_epi64
(
mask0
,
r2
,
v2
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
4
*
out_width
),
_mm512_extracti64x4_epi64
(
t2
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
5
*
out_width
),
_mm512_extracti64x4_epi64
(
t2
,
1
));
ta
=
_mm512_mask_blend_epi64
(
mask1
,
ra
,
v2
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
20
*
out_width
),
_mm512_extracti64x4_epi64
(
ta
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
21
*
out_width
),
_mm512_extracti64x4_epi64
(
ta
,
1
));
v3
=
_mm512_permutex2var_epi64
(
r3
,
idx_v
,
rb
);
t3
=
_mm512_mask_blend_epi64
(
mask0
,
r3
,
v3
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
6
*
out_width
),
_mm512_extracti64x4_epi64
(
t3
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
7
*
out_width
),
_mm512_extracti64x4_epi64
(
t3
,
1
));
tb
=
_mm512_mask_blend_epi64
(
mask1
,
rb
,
v3
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
22
*
out_width
),
_mm512_extracti64x4_epi64
(
tb
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
23
*
out_width
),
_mm512_extracti64x4_epi64
(
tb
,
1
));
v4
=
_mm512_permutex2var_epi64
(
r4
,
idx_v
,
rc
);
t4
=
_mm512_mask_blend_epi64
(
mask0
,
r4
,
v4
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
8
*
out_width
),
_mm512_extracti64x4_epi64
(
t4
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
9
*
out_width
),
_mm512_extracti64x4_epi64
(
t4
,
1
));
tc
=
_mm512_mask_blend_epi64
(
mask1
,
rc
,
v4
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
24
*
out_width
),
_mm512_extracti64x4_epi64
(
tc
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
25
*
out_width
),
_mm512_extracti64x4_epi64
(
tc
,
1
));
v5
=
_mm512_permutex2var_epi64
(
r5
,
idx_v
,
rd
);
t5
=
_mm512_mask_blend_epi64
(
mask0
,
r5
,
v5
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
10
*
out_width
),
_mm512_extracti64x4_epi64
(
t5
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
11
*
out_width
),
_mm512_extracti64x4_epi64
(
t5
,
1
));
td
=
_mm512_mask_blend_epi64
(
mask1
,
rd
,
v5
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
26
*
out_width
),
_mm512_extracti64x4_epi64
(
td
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
27
*
out_width
),
_mm512_extracti64x4_epi64
(
td
,
1
));
v6
=
_mm512_permutex2var_epi64
(
r6
,
idx_v
,
re
);
t6
=
_mm512_mask_blend_epi64
(
mask0
,
r6
,
v6
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
12
*
out_width
),
_mm512_extracti64x4_epi64
(
t6
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
13
*
out_width
),
_mm512_extracti64x4_epi64
(
t6
,
1
));
te
=
_mm512_mask_blend_epi64
(
mask1
,
re
,
v6
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
28
*
out_width
),
_mm512_extracti64x4_epi64
(
te
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
29
*
out_width
),
_mm512_extracti64x4_epi64
(
te
,
1
));
v7
=
_mm512_permutex2var_epi64
(
r7
,
idx_v
,
rf
);
t7
=
_mm512_mask_blend_epi64
(
mask0
,
r7
,
v7
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
14
*
out_width
),
_mm512_extracti64x4_epi64
(
t7
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
15
*
out_width
),
_mm512_extracti64x4_epi64
(
t7
,
1
));
tf
=
_mm512_mask_blend_epi64
(
mask1
,
rf
,
v7
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
30
*
out_width
),
_mm512_extracti64x4_epi64
(
tf
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
31
*
out_width
),
_mm512_extracti64x4_epi64
(
tf
,
1
));
r0
=
_mm512_loadu_si512
(
in
+
16
*
32
+
0
*
in_width
);
r1
=
_mm512_loadu_si512
(
in
+
16
*
32
+
1
*
in_width
);
t0
=
_mm512_unpacklo_epi16
(
r0
,
r1
);
t1
=
_mm512_unpackhi_epi16
(
r0
,
r1
);
r2
=
_mm512_loadu_si512
(
in
+
16
*
32
+
2
*
in_width
);
r3
=
_mm512_loadu_si512
(
in
+
16
*
32
+
3
*
in_width
);
t2
=
_mm512_unpacklo_epi16
(
r2
,
r3
);
t3
=
_mm512_unpackhi_epi16
(
r2
,
r3
);
r4
=
_mm512_loadu_si512
(
in
+
16
*
32
+
4
*
in_width
);
r5
=
_mm512_loadu_si512
(
in
+
16
*
32
+
5
*
in_width
);
t4
=
_mm512_unpacklo_epi16
(
r4
,
r5
);
t5
=
_mm512_unpackhi_epi16
(
r4
,
r5
);
r6
=
_mm512_loadu_si512
(
in
+
16
*
32
+
6
*
in_width
);
r7
=
_mm512_loadu_si512
(
in
+
16
*
32
+
7
*
in_width
);
t6
=
_mm512_unpacklo_epi16
(
r6
,
r7
);
t7
=
_mm512_unpackhi_epi16
(
r6
,
r7
);
r8
=
_mm512_loadu_si512
(
in
+
16
*
32
+
8
*
in_width
);
r9
=
_mm512_loadu_si512
(
in
+
16
*
32
+
9
*
in_width
);
t8
=
_mm512_unpacklo_epi16
(
r8
,
r9
);
t9
=
_mm512_unpackhi_epi16
(
r8
,
r9
);
ra
=
_mm512_loadu_si512
(
in
+
16
*
32
+
10
*
in_width
);
rb
=
_mm512_loadu_si512
(
in
+
16
*
32
+
11
*
in_width
);
ta
=
_mm512_unpacklo_epi16
(
ra
,
rb
);
tb
=
_mm512_unpackhi_epi16
(
ra
,
rb
);
rc
=
_mm512_loadu_si512
(
in
+
16
*
32
+
12
*
in_width
);
rd
=
_mm512_loadu_si512
(
in
+
16
*
32
+
13
*
in_width
);
tc
=
_mm512_unpacklo_epi16
(
rc
,
rd
);
td
=
_mm512_unpackhi_epi16
(
rc
,
rd
);
re
=
_mm512_loadu_si512
(
in
+
16
*
32
+
14
*
in_width
);
rf
=
_mm512_loadu_si512
(
in
+
16
*
32
+
15
*
in_width
);
te
=
_mm512_unpacklo_epi16
(
re
,
rf
);
tf
=
_mm512_unpackhi_epi16
(
re
,
rf
);
r0
=
_mm512_unpacklo_epi32
(
t0
,
t2
);
r1
=
_mm512_unpackhi_epi32
(
t0
,
t2
);
r2
=
_mm512_unpacklo_epi32
(
t1
,
t3
);
r3
=
_mm512_unpackhi_epi32
(
t1
,
t3
);
r4
=
_mm512_unpacklo_epi32
(
t4
,
t6
);
r5
=
_mm512_unpackhi_epi32
(
t4
,
t6
);
r6
=
_mm512_unpacklo_epi32
(
t5
,
t7
);
r7
=
_mm512_unpackhi_epi32
(
t5
,
t7
);
r8
=
_mm512_unpacklo_epi32
(
t8
,
ta
);
r9
=
_mm512_unpackhi_epi32
(
t8
,
ta
);
ra
=
_mm512_unpacklo_epi32
(
t9
,
tb
);
rb
=
_mm512_unpackhi_epi32
(
t9
,
tb
);
rc
=
_mm512_unpacklo_epi32
(
tc
,
te
);
rd
=
_mm512_unpackhi_epi32
(
tc
,
te
);
re
=
_mm512_unpacklo_epi32
(
td
,
tf
);
rf
=
_mm512_unpackhi_epi32
(
td
,
tf
);
t0
=
_mm512_unpacklo_epi64
(
r0
,
r4
);
t1
=
_mm512_unpackhi_epi64
(
r0
,
r4
);
t2
=
_mm512_unpacklo_epi64
(
r1
,
r5
);
t3
=
_mm512_unpackhi_epi64
(
r1
,
r5
);
t4
=
_mm512_unpacklo_epi64
(
r2
,
r6
);
t5
=
_mm512_unpackhi_epi64
(
r2
,
r6
);
t6
=
_mm512_unpacklo_epi64
(
r3
,
r7
);
t7
=
_mm512_unpackhi_epi64
(
r3
,
r7
);
t8
=
_mm512_unpacklo_epi64
(
r8
,
rc
);
t9
=
_mm512_unpackhi_epi64
(
r8
,
rc
);
ta
=
_mm512_unpacklo_epi64
(
r9
,
rd
);
tb
=
_mm512_unpackhi_epi64
(
r9
,
rd
);
tc
=
_mm512_unpacklo_epi64
(
ra
,
re
);
td
=
_mm512_unpackhi_epi64
(
ra
,
re
);
te
=
_mm512_unpacklo_epi64
(
rb
,
rf
);
tf
=
_mm512_unpackhi_epi64
(
rb
,
rf
);
r0
=
_mm512_shuffle_i32x4
(
t0
,
t1
,
0x88
);
r1
=
_mm512_shuffle_i32x4
(
t2
,
t3
,
0x88
);
r2
=
_mm512_shuffle_i32x4
(
t4
,
t5
,
0x88
);
r3
=
_mm512_shuffle_i32x4
(
t6
,
t7
,
0x88
);
r4
=
_mm512_shuffle_i32x4
(
t0
,
t1
,
0xdd
);
r5
=
_mm512_shuffle_i32x4
(
t2
,
t3
,
0xdd
);
r6
=
_mm512_shuffle_i32x4
(
t4
,
t5
,
0xdd
);
r7
=
_mm512_shuffle_i32x4
(
t6
,
t7
,
0xdd
);
r8
=
_mm512_shuffle_i32x4
(
t8
,
t9
,
0x88
);
r9
=
_mm512_shuffle_i32x4
(
ta
,
tb
,
0x88
);
ra
=
_mm512_shuffle_i32x4
(
tc
,
td
,
0x88
);
rb
=
_mm512_shuffle_i32x4
(
te
,
tf
,
0x88
);
rc
=
_mm512_shuffle_i32x4
(
t8
,
t9
,
0xdd
);
rd
=
_mm512_shuffle_i32x4
(
ta
,
tb
,
0xdd
);
re
=
_mm512_shuffle_i32x4
(
tc
,
td
,
0xdd
);
rf
=
_mm512_shuffle_i32x4
(
te
,
tf
,
0xdd
);
v0
=
_mm512_permutex2var_epi64
(
r0
,
idx_v
,
r8
);
t0
=
_mm512_mask_blend_epi64
(
mask0
,
r0
,
v0
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
0
*
out_width
),
_mm512_extracti64x4_epi64
(
t0
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
1
*
out_width
),
_mm512_extracti64x4_epi64
(
t0
,
1
));
t8
=
_mm512_mask_blend_epi64
(
mask1
,
r8
,
v0
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
16
*
out_width
),
_mm512_extracti64x4_epi64
(
t8
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
17
*
out_width
),
_mm512_extracti64x4_epi64
(
t8
,
1
));
v1
=
_mm512_permutex2var_epi64
(
r1
,
idx_v
,
r9
);
t1
=
_mm512_mask_blend_epi64
(
mask0
,
r1
,
v1
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
2
*
out_width
),
_mm512_extracti64x4_epi64
(
t1
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
3
*
out_width
),
_mm512_extracti64x4_epi64
(
t1
,
1
));
t9
=
_mm512_mask_blend_epi64
(
mask1
,
r9
,
v1
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
18
*
out_width
),
_mm512_extracti64x4_epi64
(
t9
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
19
*
out_width
),
_mm512_extracti64x4_epi64
(
t9
,
1
));
v2
=
_mm512_permutex2var_epi64
(
r2
,
idx_v
,
ra
);
t2
=
_mm512_mask_blend_epi64
(
mask0
,
r2
,
v2
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
4
*
out_width
),
_mm512_extracti64x4_epi64
(
t2
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
5
*
out_width
),
_mm512_extracti64x4_epi64
(
t2
,
1
));
ta
=
_mm512_mask_blend_epi64
(
mask1
,
ra
,
v2
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
20
*
out_width
),
_mm512_extracti64x4_epi64
(
ta
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
21
*
out_width
),
_mm512_extracti64x4_epi64
(
ta
,
1
));
v3
=
_mm512_permutex2var_epi64
(
r3
,
idx_v
,
rb
);
t3
=
_mm512_mask_blend_epi64
(
mask0
,
r3
,
v3
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
6
*
out_width
),
_mm512_extracti64x4_epi64
(
t3
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
7
*
out_width
),
_mm512_extracti64x4_epi64
(
t3
,
1
));
tb
=
_mm512_mask_blend_epi64
(
mask1
,
rb
,
v3
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
22
*
out_width
),
_mm512_extracti64x4_epi64
(
tb
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
23
*
out_width
),
_mm512_extracti64x4_epi64
(
tb
,
1
));
v4
=
_mm512_permutex2var_epi64
(
r4
,
idx_v
,
rc
);
t4
=
_mm512_mask_blend_epi64
(
mask0
,
r4
,
v4
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
8
*
out_width
),
_mm512_extracti64x4_epi64
(
t4
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
9
*
out_width
),
_mm512_extracti64x4_epi64
(
t4
,
1
));
tc
=
_mm512_mask_blend_epi64
(
mask1
,
rc
,
v4
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
24
*
out_width
),
_mm512_extracti64x4_epi64
(
tc
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
25
*
out_width
),
_mm512_extracti64x4_epi64
(
tc
,
1
));
v5
=
_mm512_permutex2var_epi64
(
r5
,
idx_v
,
rd
);
t5
=
_mm512_mask_blend_epi64
(
mask0
,
r5
,
v5
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
10
*
out_width
),
_mm512_extracti64x4_epi64
(
t5
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
11
*
out_width
),
_mm512_extracti64x4_epi64
(
t5
,
1
));
td
=
_mm512_mask_blend_epi64
(
mask1
,
rd
,
v5
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
26
*
out_width
),
_mm512_extracti64x4_epi64
(
td
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
27
*
out_width
),
_mm512_extracti64x4_epi64
(
td
,
1
));
v6
=
_mm512_permutex2var_epi64
(
r6
,
idx_v
,
re
);
t6
=
_mm512_mask_blend_epi64
(
mask0
,
r6
,
v6
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
12
*
out_width
),
_mm512_extracti64x4_epi64
(
t6
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
13
*
out_width
),
_mm512_extracti64x4_epi64
(
t6
,
1
));
te
=
_mm512_mask_blend_epi64
(
mask1
,
re
,
v6
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
28
*
out_width
),
_mm512_extracti64x4_epi64
(
te
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
29
*
out_width
),
_mm512_extracti64x4_epi64
(
te
,
1
));
v7
=
_mm512_permutex2var_epi64
(
r7
,
idx_v
,
rf
);
t7
=
_mm512_mask_blend_epi64
(
mask0
,
r7
,
v7
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
14
*
out_width
),
_mm512_extracti64x4_epi64
(
t7
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
15
*
out_width
),
_mm512_extracti64x4_epi64
(
t7
,
1
));
tf
=
_mm512_mask_blend_epi64
(
mask1
,
rf
,
v7
);
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
30
*
out_width
),
_mm512_extracti64x4_epi64
(
tf
,
0
));
_mm256_storeu_si256
((
__m256i
*
)(
out
+
16
+
31
*
out_width
),
_mm512_extracti64x4_epi64
(
tf
,
1
));
#else
LIBXSMM_UNUSED
(
in
);
LIBXSMM_UNUSED
(
out
);
#endif
}
/* Forward declarations of the datatype- and layout-specialized backward/weight-update
 * kernels for the RNN cell; the generic dispatchers below select among them based on
 * libxsmm_target_archid and the descriptor's input/output datatypes.
 * Layout suffixes: nc_ck = NC activations with CK filters, nc_kcck = NC activations with
 * blocked KCCK filters, ncnc_kcck = blocked activations with blocked filters.
 * Variant suffixes: _emu = AVX512_CORE bf16 emulation, _amx = AMX/SPR code path. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
/* FP32 backward/weight-update for the NC x CK layout (AVX512 path).
 * Selects the cell-type-specific template (ReLU/sigmoid/tanh RNN, LSTM, GRU) at
 * compile time via preprocessor includes; the templates consume the element_*_type
 * typedefs and the handle/kind/start_thread/tid variables in scope here. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
#define LIBXSMM_RNN_CELL_AVX512
  /* element types consumed by the included templates */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_BWDUPD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_BWDUPD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c"
  } else {
    /* should not happen */
  }
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 backward/weight-update for the NC x CK layout, emulated on AVX512_CORE
 * (no native BF16 instructions). Only the LSTM cell is implemented; the plain
 * RNN variants and GRU report LIBXSMM_DNN_ERR_NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the NC x CK layout using native AVX512-BF16
 * (Cooper Lake) instructions; LSTM only, other cell types are not implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Build lacks AVX512-BF16 support: forward to the AVX512_CORE emulation path. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  return libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the NC x CK layout on the AMX (SPR) code path,
 * compiled with native AVX512-BF16 support; LSTM only. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Build lacks AVX512-BF16 support: same AMX-template dispatch, but compiled for
 * AVX512_CORE with emulated BF16 conversion macros. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic_bf16_amx.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* BF16 backward/weight-update for the NC x KCCK (blocked-filter) layout, emulated
 * on AVX512_CORE; LSTM only, other cell types report NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the blocked NCNC x KCCK layout on the AMX (SPR)
 * code path, compiled with native AVX512-BF16 support; LSTM only. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Build lacks AVX512-BF16 support: same AMX-template dispatch, but compiled for
 * AVX512_CORE with emulated BF16 conversion macros. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__ */
#define LIBXSMM_RNN_CELL_AVX512
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_ncnc_kcck_bf16_amx.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the NC x KCCK layout using native AVX512-BF16
 * (Cooper Lake) instructions; LSTM only. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Build lacks AVX512-BF16 support: forward to the AVX512_CORE emulation path. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  return libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the NC x KCCK layout on the AMX (SPR) code path,
 * compiled with native AVX512-BF16 support; LSTM only. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
#define LIBXSMM_RNN_CELL_AVX512
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Build lacks AVX512-BF16 support: same AMX-template dispatch, but compiled for
 * AVX512_CORE with emulated BF16 conversion macros. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_bf16_amx.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* FP32 backward/weight-update for the NC x KCCK (blocked-filter) layout (AVX512 path).
 * Instantiates the cell-type-specific KCCK template via preprocessor includes. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
#define LIBXSMM_RNN_CELL_AVX512
  /* element types consumed by the included templates */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_BWDUPD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_BWDUPD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c"
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c"
  } else {
    /* should not happen */
  }
#undef LIBXSMM_RNN_CELL_AVX512
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* FP32 backward/weight-update for the blocked NCNC x KCCK layout.
 * Currently a stub: the implementation is parked behind `#if 0` and the function
 * always reports LIBXSMM_DNN_ERR_NOT_IMPLEMENTED on AVX512 builds (or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH otherwise).
 * Fix vs. original: the AVX512 branch invoked LIBXSMM_UNUSED on handle,
 * start_thread and tid twice (before and after the disabled code); each
 * parameter is now marked unused exactly once. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
#if 0
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_ncnc_kcck_generic.tpl.c"
#endif
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  LIBXSMM_UNUSED( kind );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Generic dispatcher for backward/weight-update in the NC x CK layout.
 * Routes to a specialized kernel based on the runtime target architecture
 * (libxsmm_target_archid) and the descriptor's input/output datatypes; on
 * non-AVX512 targets it falls back to instantiating the generic FP32 templates
 * directly in this function. Returns LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE for
 * datatype combinations without a kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) && ( libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT ) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_f32_f32( handle, kind, start_thread, tid );
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      /* the BF16 kernels process minibatch entries pairwise and require even N */
      if ( handle->desc.N % 2 != 0 ) {
        status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
      } else {
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
        /* pick emulation (CORE), native BF16 (CPX) or AMX (SPR) by target arch */
        if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid );
        } else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16( handle, kind, start_thread, tid );
        } else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_amx( handle, kind, start_thread, tid );
        }
#else
        if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck_bf16_bf16_emu( handle, kind, start_thread, tid );
        }
#endif
        else {
          status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          return status;
        }
      }
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* non-AVX512 fallback: only FP32 is supported via the generic templates */
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_BWDUPD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_BWDUPD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_ck_generic.tpl.c"
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_ck_generic.tpl.c"
      } else {
        /* should not happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Thread-entry dispatcher for the RNN-cell backward / weight-update pass with
 * NC activation layout and KCCK (blocked) filter layout.
 *
 * Selects an architecture/datatype-specialized kernel when AVX512 support is
 * compiled in and detected at runtime; otherwise falls back to the generic C
 * templates (F32 only).
 *
 * @param handle        RNN-cell handle carrying the layer descriptor (desc).
 * @param kind          compute kind (backward, weight-update, or both).
 * @param start_thread  first thread id of the calling team.
 * @param tid           id of the calling thread.
 * @return LIBXSMM_DNN_SUCCESS or an error/unsupported status code. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_f32_f32( handle, kind, start_thread, tid );
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      /* BF16 kernels process rows pairwise; an odd minibatch is not implemented */
      if ( handle->desc.N % 2 != 0 ) {
        status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
      } else {
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
        /* full dispatch: emulated on CORE..CPX, native BF16 on CPX..SPR, AMX on SPR+ */
        if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
        }
        else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16( handle, kind, start_thread, tid );
        }
        else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid );
        }
#else
        /* no CPX intrinsics compiled in: emulate on CORE..SPR, AMX kernel on SPR+ */
        if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
        }
        else if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
          status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid );
        }
#endif
        /* binds to the last arch test of whichever branch was compiled */
        else {
          status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          return status;
        }
      }
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic (non-AVX512) fall-back via C templates; F32 only */
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* element types consumed by the included templates */
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
#define LIBXSMM_DNN_RNN_RELU_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_BWDUPD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
#define LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_BWDUPD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
#define LIBXSMM_DNN_RNN_TANH_BWDUPD
# include "template/libxsmm_dnn_rnncell_st_rnn_bwdupd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_BWDUPD
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
# include "template/libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck.tpl.c"
      } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_bwdupd_nc_kcck.tpl.c"
      } else {
        /* should not happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Thread-entry dispatcher for the RNN-cell backward / weight-update pass with
 * fully blocked NCNC activation layout and KCCK filter layout.
 *
 * Supported only with AVX512 compiled in: F32 on any AVX512 target, BF16 only
 * on SPR (AMX) targets; there is no generic fall-back implementation for this
 * layout (the scalar path returns NOT_IMPLEMENTED for F32).
 *
 * @param handle        RNN-cell handle carrying the layer descriptor (desc).
 * @param kind          compute kind (backward, weight-update, or both).
 * @param start_thread  first thread id of the calling team.
 * @param tid           id of the calling thread.
 * @return LIBXSMM_DNN_SUCCESS or an error/unsupported status code. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512; the two preprocessor branches below are
   * identical except for the intrinsics level they were compiled against */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid );
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#elif defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid );
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck_bf16_bf16_amx( handle, kind, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* no AVX512 (at compile or run time): this layout has no generic kernel */
    if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      LIBXSMM_UNUSED( kind );
      LIBXSMM_UNUSED( start_thread );
      LIBXSMM_UNUSED( tid );
      status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
third_party/libxsmm/src/libxsmm_dnn_rnncell_backward_weight_update.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H
#define LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H

#include <libxsmm_dnn.h>
#include <libxsmm_dnn_rnncell.h>

/* Internal thread-entry dispatchers for the RNN-cell backward / weight-update
 * pass, one per supported activation/filter tensor layout (nc_ck, nc_kcck,
 * ncnc_kcck).  Each routes to an architecture/datatype-specific kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_ck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_nc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_bwdupd_ncnc_kcck(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);

#endif /* LIBXSMM_DNN_RNNCELL_BACKWARD_WEIGHT_UPDATE_H */
third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Kunal Banerjee (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_rnncell_forward.h"
#include "libxsmm_dnn_elementwise.h"
#include "libxsmm_main.h"
/* Forward declarations of the architecture/datatype-specialized forward
 * kernels defined below in this translation unit.  Naming scheme:
 * <layout>_<dtype-in>_<dtype-out>[_emu|_amx], where _emu is the BF16 path
 * emulated with AVX512-CORE instructions and _amx targets SPR. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
/* F32 forward kernel, NC activations / CK filters.  The actual computation
 * lives in the included template files; this function only selects the
 * template by cell type and provides the element typedefs it consumes. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included templates */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
  } else {
    /* should not happen */
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 forward kernel (AVX512-CORE emulation of BF16 arithmetic), NC / CK
 * layouts.  Only the LSTM cell is implemented for BF16; other cell types
 * return NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* Native BF16 forward kernel (AVX512-BF16 / CPX instructions), NC / CK
 * layouts.  Only LSTM is implemented; other cell types return
 * NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* CPX intrinsics not compiled in: forward to the AVX512-CORE emulated kernel */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  return libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 forward kernel for SPR (AMX template), NC / CK layouts; built with
 * CPX instructions when available.  Only LSTM is implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* CPX intrinsics not compiled in: same AMX template built at AVX512-CORE
 * level, without the CPX BF16 native-instruction macros */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__ */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* F32 forward kernel, fully blocked NCNC activations / KCCK filters.  Only
 * the elementwise RNN cells (ReLU/sigmoid/tanh) are implemented for this
 * layout; LSTM and GRU return NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included templates */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* F32 forward kernel, NC activations / KCCK (blocked) filters.  All cell
 * types are implemented for this layout via the included templates. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included templates */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
# define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_RELU_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
# define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_SIGMOID_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
# define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
# undef LIBXSMM_DNN_RNN_TANH_FWD
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
  } else {
    /* should not happen */
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 forward kernel (AVX512-CORE emulation), NC / KCCK layouts.  Only LSTM
 * is implemented for BF16; other cell types return NOT_IMPLEMENTED. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* Native BF16 forward kernel (AVX512-BF16 / CPX instructions), NC / KCCK
 * layouts.  Only LSTM is implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* CPX intrinsics not compiled in: forward to the AVX512-CORE emulated kernel */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  return libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid );
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 forward kernel for SPR (AMX template), fully blocked NCNC / KCCK
 * layouts; built with CPX instructions when available.  Only LSTM is
 * implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* CPX intrinsics not compiled in: same AMX template built at AVX512-CORE
 * level, without the CPX BF16 native-instruction macros */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_ncnc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 forward kernel for SPR (AMX template), NC / KCCK layouts; built with
 * CPX instructions when available.  Only LSTM is implemented. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* CPX intrinsics not compiled in: same AMX template built at AVX512-CORE
 * level, without the CPX BF16 native-instruction macros */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__ */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
#define LIBXSMM_RNN_CELL_AVX512
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16_amx.tpl.c"
#undef LIBXSMM_RNN_CELL_AVX512
  } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
    status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
  } else {
    /* should not happen */
  }
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED( handle );
  LIBXSMM_UNUSED( start_thread );
  LIBXSMM_UNUSED( tid );
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#endif
/* Dispatcher for the RNN-cell forward pass with NC activations and CK filters.
 * Selects a JIT-backed kernel by datatype and the detected CPU architecture;
 * on non-AVX512 targets it falls back to the generic FP32 templates. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  if ((libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /* __AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__ */
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX) {
      /* pre-CPX cores: emulated BF16 conversions */
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /* __AVX512F__, __AVX512BW__, __AVX512DQ__ */
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_amx(handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* generic (non-AVX512) path: FP32 only */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU) {
#define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID) {
#define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH) {
#define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM) {
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
      }
      else {
        /* should not happen */
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Dispatcher for the RNN-cell forward pass with blocked NCNC activations and
 * KCCK filters. AVX512 targets route to the JIT kernels (FP32, or BF16 on
 * SPR-class hardware); the generic fallback supports FP32 vanilla RNN cells
 * only (LSTM/GRU are not implemented in the generic path). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /* __AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__ */
  if ((libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#elif defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  if ((libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_bf16_bf16_amx(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* generic path: FP32 vanilla RNN cells only */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU) {
#define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID) {
#define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH) {
#define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM) {
        status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU) {
        status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
      }
      else {
        /* should not happen */
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
/* Dispatcher for the RNN-cell forward pass with NC activations and KCCK
 * filters. Picks the JIT kernel by datatype and CPU generation; the generic
 * (non-AVX512) fallback covers all cell types in FP32. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
#if 0
  if (handle->? == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
#endif
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  if ((libxsmm_target_archid >= LIBXSMM_X86_AVX512) && (libxsmm_target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /* __AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__ */
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /* __AVX512F__, __AVX512BW__, __AVX512DQ__ */
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_SPR) {
      /* no native BF16 FMA below SPR in this build: use the emulated kernel */
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
             && libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_amx(handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* generic (non-AVX512) path: FP32 only, all cell types */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU) {
#define LIBXSMM_DNN_RNN_RELU_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_RELU_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID) {
#define LIBXSMM_DNN_RNN_SIGMOID_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_SIGMOID_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH) {
#define LIBXSMM_DNN_RNN_TANH_FWD
# include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
#undef LIBXSMM_DNN_RNN_TANH_FWD
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM) {
# include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
      }
      else if (handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU) {
# include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
      }
      else {
        /* should not happen */
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
third_party/libxsmm/src/libxsmm_dnn_rnncell_forward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_RNNCELL_FORWARD_H
#define LIBXSMM_DNN_RNNCELL_FORWARD_H

#include <libxsmm_dnn.h>
#include <libxsmm_dnn_rnncell.h>

/* Internal single-threaded-entry forward-pass dispatchers for the supported
 * tensor layouts (NC/CK, NCNC/KCCK, NC/KCCK); each thread of the team calls
 * with its own tid relative to start_thread. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid);

#endif /* LIBXSMM_DNN_RNNCELL_FORWARD_H */
third_party/libxsmm/src/libxsmm_dnn_softmaxloss.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_softmaxloss_backward.h"
#include "libxsmm_dnn_softmaxloss_forward.h"
#include "libxsmm_main.h"
/* Allocates and initializes a softmax-loss handle from the given descriptor.
 * Supports F32 and BF16 data in either the LIBXSMM-blocked or NCPACKED buffer
 * format; computes the blocking factors, creates the thread barrier and sizes
 * the per-handle scratch. Returns NULL and sets *status on failure. */
LIBXSMM_API libxsmm_dnn_softmaxloss* libxsmm_dnn_create_softmaxloss(libxsmm_dnn_softmaxloss_desc softmaxloss_desc, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_softmaxloss* handle = 0;
  int lpb;
  /* init libxsmm */
  LIBXSMM_INIT
  if ((softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_F32) || (softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16)) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    handle = (libxsmm_dnn_softmaxloss*)calloc(1, sizeof(libxsmm_dnn_softmaxloss));
    if (0 != handle) {
      *status = LIBXSMM_DNN_SUCCESS;
      /* keep a persistent copy of the descriptor */
      handle->desc = softmaxloss_desc;
      if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
        int bk;
        /* derive the feature-map blocking from C and the datatype */
        *status = libxsmm_dnn_get_feature_map_blocks(handle->desc.C, handle->desc.C,
          &(handle->bc), &bk, &lpb, handle->desc.datatype, handle->desc.datatype);
        /* outer blocks */
        handle->Bc = handle->desc.C / handle->bc;
        handle->bn = 1;
        handle->Bn = handle->desc.N;
      }
      else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
        /* blocking is taken verbatim from the descriptor */
        handle->bc = handle->desc.bc;
        handle->bn = handle->desc.bn;
        handle->Bc = handle->desc.C / handle->bc;
        handle->Bn = handle->desc.N / handle->bn;
      }
      else {
        *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
        free(handle);
        handle = 0;
        return handle;
      }
      /* create barrier */
      handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
      /* BF16 needs FP32 staging copies (two N x C planes); F32 needs none */
      if (softmaxloss_desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) {
        handle->scratch_size = (sizeof(float) * handle->desc.C * handle->desc.N * 2);
      }
      else {
        handle->scratch_size = 1;
      }
    }
    else {
      *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  return handle;
}
/* Releases a softmax-loss handle (barrier first, then the handle itself).
 * A NULL handle yields LIBXSMM_DNN_ERR_INVALID_HANDLE. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_softmaxloss(const libxsmm_dnn_softmaxloss* handle)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    /* deallocate the barrier, if any */
    if (handle->barrier != 0) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    /* deallocate handle structure (cast drops constness) */
    free((libxsmm_dnn_softmaxloss*)handle);
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
/* Builds the expected data layout for one of the handle's tensors:
 * activations (regular/gradient input, regular output) follow the handle's
 * buffer format (3-dim for LIBXSMM-blocked, 4-dim for NCPACKED), while the
 * label tensor is a flat I32 vector of length N. Returns NULL and sets
 * *status on any error; the caller owns the returned layout and releases it
 * via libxsmm_dnn_destroy_tensor_datalayout(). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_softmaxloss_create_tensor_datalayout(const libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor_datalayout* layout;
  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;
  if (handle != 0) {
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (layout != 0) {
      layout->format = handle->desc.buffer_format;
      if ((type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
          (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT)) {
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          layout->datatype = handle->desc.datatype;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(3 * sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*)malloc(3 * sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 3;
            /* innermost-first: [bc][Bc][N] */
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->bc;
            layout->dim_size[1] = handle->Bc;
            layout->dim_size[2] = handle->desc.N;
          }
          else {
            /* BUGFIX: one of the two allocations may have succeeded even if the
             * other failed; free both (free(NULL) is a no-op, and calloc above
             * zeroed the pointers) before dropping the layout to avoid a leak. */
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        }
        else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
          layout->datatype = handle->desc.datatype;
          layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(4 * sizeof(libxsmm_dnn_tensor_dimtype));
          layout->dim_size = (unsigned int*)malloc(4 * sizeof(unsigned int));
          if (0 != layout->dim_type && 0 != layout->dim_size) {
            layout->num_dims = 4;
            /* innermost-first: [bc][bn][Bc][Bn] */
            layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
            layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
            layout->dim_size[0] = handle->bc;
            layout->dim_size[1] = handle->bn;
            layout->dim_size[2] = handle->Bc;
            layout->dim_size[3] = handle->Bn;
          }
          else {
            /* BUGFIX: avoid leaking a partially allocated dim array (see above) */
            free(layout->dim_type);
            free(layout->dim_size);
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
          }
        }
        else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      }
      else if (type == LIBXSMM_DNN_LABEL) {
        /* labels are a flat I32 vector of length N, independent of buffer format */
        layout->datatype = LIBXSMM_DNN_DATATYPE_I32;
        layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(1 * sizeof(libxsmm_dnn_tensor_dimtype));
        layout->dim_size = (unsigned int*)malloc(1 * sizeof(unsigned int));
        if (0 != layout->dim_type && 0 != layout->dim_size) {
          layout->num_dims = 1;
          layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
          layout->dim_size[0] = handle->desc.N;
        }
        else {
          /* BUGFIX: avoid leaking a partially allocated dim array (see above) */
          free(layout->dim_type);
          free(layout->dim_size);
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
        }
      }
      else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    }
    else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return layout;
}
/* Returns the scratch size required by the handle, padded by 64 bytes so that
 * libxsmm_dnn_softmaxloss_bind_scratch() can realign an unaligned buffer. */
LIBXSMM_API size_t libxsmm_dnn_softmaxloss_get_scratch_size(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status)
{
  size_t l_scratch_size = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    /* 64 byte extra in case the user code does not care about alignment */
    l_scratch_size = handle->scratch_size + 64;
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return l_scratch_size;
}
/* Returns the currently bound (already aligned) scratch pointer, or NULL with
 * LIBXSMM_DNN_ERR_INVALID_HANDLE when the handle is NULL. */
LIBXSMM_API void* libxsmm_dnn_softmaxloss_get_scratch_ptr(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status)
{
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    return handle->scratch;
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return 0;
}
/* Binds a user-provided scratch buffer to the handle, rounding the stored
 * pointer up to the next 64-byte boundary (the extra 64 bytes reported by
 * libxsmm_dnn_softmaxloss_get_scratch_size() cover this adjustment). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_scratch(libxsmm_dnn_softmaxloss* handle, const void* scratch)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  uintptr_t address = (uintptr_t)scratch;
  size_t offset = 0;

  if (scratch == 0) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
    return status;
  }
  if (0 != handle) {
    /* align the internal scratch buffer if needed */
    if (address % 64 == 0) {
      handle->scratch = (void*)address;
    }
    else {
      offset = (64 - address % 64);
      handle->scratch = (void*)(address + offset);
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
/* Detaches the scratch buffer from the handle; the memory itself is owned by
 * the caller and is not freed here. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_scratch(libxsmm_dnn_softmaxloss* handle)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    handle->scratch = 0;
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
/* Binds a tensor to one of the handle's slots (regular/gradient input,
 * regular output, or label) after verifying that the tensor's layout matches
 * the layout the handle expects for that slot. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_bind_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ((type != LIBXSMM_DNN_REGULAR_INPUT)  && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
      (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_LABEL)) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }

  if (handle != 0 && tensor != 0) {
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_softmaxloss_create_tensor_datalayout(handle, type, &status);
    if (libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0) {
      switch (type) {
        case LIBXSMM_DNN_REGULAR_INPUT:  handle->reg_input  = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_GRADIENT_INPUT: handle->grad_input = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_REGULAR_OUTPUT: handle->reg_output = (libxsmm_dnn_tensor*)tensor; break;
        case LIBXSMM_DNN_LABEL:          handle->label      = (libxsmm_dnn_tensor*)tensor; break;
        default: /* cannot happen: filtered above */ break;
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }
    libxsmm_dnn_destroy_tensor_datalayout(handle_layout);
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }
  return status;
}
/* Returns the tensor currently bound to the requested slot (or NULL when the
 * slot is empty); rejects unknown tensor types and NULL handles via *status. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_softmaxloss_get_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor* return_tensor = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ((type != LIBXSMM_DNN_REGULAR_INPUT)  && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
      (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_LABEL)) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return return_tensor;
  }

  if (handle != 0) {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:  return_tensor = handle->reg_input;  break;
      case LIBXSMM_DNN_GRADIENT_INPUT: return_tensor = handle->grad_input; break;
      case LIBXSMM_DNN_REGULAR_OUTPUT: return_tensor = handle->reg_output; break;
      case LIBXSMM_DNN_LABEL:          return_tensor = handle->label;      break;
      default: /* cannot happen: filtered above */ break;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return return_tensor;
}
/* Unbinds the tensor in the requested slot; the tensor object itself remains
 * owned by the caller and is not destroyed. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_release_tensor(libxsmm_dnn_softmaxloss* handle, const libxsmm_dnn_tensor_type type)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ((type != LIBXSMM_DNN_REGULAR_INPUT)  && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
      (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_LABEL)) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }

  if (handle != 0) {
    switch (type) {
      case LIBXSMM_DNN_REGULAR_INPUT:  handle->reg_input  = 0; break;
      case LIBXSMM_DNN_GRADIENT_INPUT: handle->grad_input = 0; break;
      case LIBXSMM_DNN_REGULAR_OUTPUT: handle->reg_output = 0; break;
      case LIBXSMM_DNN_LABEL:          handle->label      = 0; break;
      default: /* cannot happen: filtered above */ break;
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
/* Single-threaded-entry execution: routes the requested compute kind to the
 * forward or backward NCNC implementation; each participating thread calls
 * with its tid relative to start_thread. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_execute_st(libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
        status = libxsmm_dnn_softmaxloss_st_fwd_ncnc(handle, start_thread, tid);
      } break;
      case LIBXSMM_DNN_COMPUTE_KIND_BWD: {
        status = libxsmm_dnn_softmaxloss_st_bwd_ncnc(handle, start_thread, tid);
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_KIND;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return status;
}
/* Returns the loss value accumulated by the most recent forward pass, or
 * 0.0f with LIBXSMM_DNN_ERR_INVALID_HANDLE when the handle is NULL. */
LIBXSMM_API float libxsmm_dnn_softmaxloss_get_loss(const libxsmm_dnn_softmaxloss* handle, libxsmm_dnn_err_t* status)
{
  float l_loss = 0.0f;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    l_loss = handle->loss;
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  return l_loss;
}
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_softmaxloss_backward.h"
#include "libxsmm_main.h"
/* Forward declarations of the per-datatype backward kernels defined below. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
/* AVX512 FP32 backward kernel: instantiates the generic NCNC backward
 * template with float element types. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 BF16 backward kernel: instantiates the generic NCNC backward
 * template with bfloat16 element types and the AVX512-BF16 code path. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16_AVX512
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Backward-pass dispatcher: validates that the gradient input, regular output
 * and label tensors are bound, then picks the AVX512 kernel or the generic
 * template instantiation by runtime architecture and datatype. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and mask */
  if (handle->grad_input == 0 || handle->reg_output == 0 || handle->label == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  if (libxsmm_target_archid >= LIBXSMM_X86_AVX512) {
    if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_softmaxloss_st_bwd_ncnc_f32_f32(handle, start_thread, tid);
    }
    else if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) {
      status = libxsmm_dnn_softmaxloss_st_bwd_ncnc_bf16_bf16(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
    }
    else if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
      typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16
# include "template/libxsmm_dnn_softmaxloss_st_bwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_BWD_BF16
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_backward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H
#define LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H

#include <libxsmm_dnn_softmaxloss.h>

/* Internal single-threaded-entry backward-pass dispatcher (NCNC layout). */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_bwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);

#endif /* LIBXSMM_DNN_SOFTMAXLOSS_BACKWARD_H */
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_softmaxloss_forward.h"
#include "libxsmm_main.h"
/* Forward declarations of the per-datatype forward kernels defined below. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid);
/* AVX512 FP32 forward kernel: instantiates the generic NCNC forward template
 * with float element types. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* AVX512 BF16 forward kernel: instantiates the generic NCNC forward template
 * with bfloat16 element types and the AVX512-BF16 code path. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16_AVX512
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Dispatcher for the softmax-loss forward pass (NCNC layout).
 * Validates that input, output and label buffers are bound, then routes to
 * the AVX-512 specialized kernels when the detected target architecture
 * supports them, falling back to the scalar/generic template otherwise.
 * Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE on precondition failures. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_softmaxloss_st_fwd_ncnc(libxsmm_dnn_softmaxloss* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and mask */
  if (handle->reg_input == 0 || handle->reg_output == 0 || handle->label == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (libxsmm_target_archid >= LIBXSMM_X86_AVX512) {
    /* runtime dispatch: AVX-512 capable CPU gets the intrinsic kernels */
    if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_f32_f32(handle, start_thread, tid);
    }
    else if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) {
      status = libxsmm_dnn_softmaxloss_st_fwd_ncnc_bf16_bf16(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  { /* generic fallback: the template is instantiated inline per datatype */
    if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef int element_label_type;
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
    }
    else if (handle->desc.datatype == LIBXSMM_DNN_DATATYPE_BF16) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
      typedef int element_label_type;
# define LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
# include "template/libxsmm_dnn_softmaxloss_st_fwd_ncnc_generic.tpl.c"
# undef LIBXSMM_DNN_SOFTMAXLOSS_FWD_BF16
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
third_party/libxsmm/src/libxsmm_dnn_softmaxloss_forward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H
#define LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H
#include <libxsmm_dnn_softmaxloss.h>
LIBXSMM_API_INTERN
libxsmm_dnn_err_t
libxsmm_dnn_softmaxloss_st_fwd_ncnc
(
libxsmm_dnn_softmaxloss
*
handle
,
int
start_thread
,
int
tid
);
#endif
/* LIBXSMM_DNN_SOFTMAXLOSS_FORWARD_H */
third_party/libxsmm/src/libxsmm_dnn_tensor.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include <libxsmm.h>
#include "libxsmm_main.h"
#include "libxsmm_dnn_tensor.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/* Wraps user-provided memory as a (non-quantized) tensor described by the
 * given layout. Convenience front-end for libxsmm_dnn_link_qtensor. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status)
{
  /* a plain tensor is a quantized tensor whose scaling factor is zero */
  const unsigned char no_scf = 0;
  return libxsmm_dnn_link_qtensor(layout, data, no_scf, status);
}
/* Wraps user-provided memory as a quantized tensor: allocates the handle,
 * deep-copies the layout and records the data pointer and scaling factor.
 * On any failure *status is set (LIBXSMM_DNN_ERR_CREATE_TENSOR for invalid
 * arguments or allocation failure) and NULL is returned; the data pointer
 * itself is never owned by the tensor. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char scf, libxsmm_dnn_err_t* status)
{
  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  libxsmm_dnn_tensor* result = (libxsmm_dnn_tensor*)calloc(1, sizeof(libxsmm_dnn_tensor));
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout || 0 == result || 0 == data) {
    *status = LIBXSMM_DNN_ERR_CREATE_TENSOR;
  }
  else {
    result->layout = libxsmm_dnn_duplicate_tensor_datalayout(layout, status);
    result->data = (void*)data;
    result->scf = scf;
    /* when the layout copy failed, release the (partial) layout */
    if (LIBXSMM_DNN_SUCCESS != *status) {
      libxsmm_dnn_destroy_tensor_datalayout(result->layout);
    }
  }
  if (LIBXSMM_DNN_SUCCESS != *status) { /* roll back on any error */
    free(result);
    result = 0;
  }
  return result;
}
/* Deep-copies a tensor data-layout (header fields plus the per-dimension
 * type and size arrays). Returns the copy, or NULL with *status set to
 * LIBXSMM_DNN_ERR_INVALID_LAYOUT when the source is NULL/empty.
 * NOTE(review): if one of the two inner malloc calls fails, *status is set
 * to LIBXSMM_DNN_ERR_CREATE_LAYOUT but the partially-populated dst_layout
 * is still returned — callers are expected to check *status and release it
 * via libxsmm_dnn_destroy_tensor_datalayout (which tolerates NULL members);
 * confirm all callers do so. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor_datalayout* dst_layout;
  *status = LIBXSMM_DNN_SUCCESS;
  dst_layout = 0;
  if (layout != 0 && layout->num_dims != 0) {
    unsigned int dim = 0;
    /* zero entire content; not only safer but also sets data and code pointers to NULL */
    dst_layout = (libxsmm_dnn_tensor_datalayout*)calloc(1, sizeof(libxsmm_dnn_tensor_datalayout));
    if (0 != dst_layout) {
      /* per-dimension arrays are sized by the source's dimension count */
      dst_layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(layout->num_dims * sizeof(libxsmm_dnn_tensor_dimtype));
      dst_layout->dim_size = (unsigned int*)malloc(layout->num_dims * sizeof(unsigned int));
      dst_layout->num_dims = layout->num_dims;
      dst_layout->format = layout->format;
      dst_layout->datatype = layout->datatype;
      dst_layout->tensor_type = layout->tensor_type;
      if (0 != dst_layout->dim_type && 0 != dst_layout->dim_size) {
        for (dim = 0; dim < layout->num_dims; ++dim) {
          dst_layout->dim_type[dim] = layout->dim_type[dim];
          dst_layout->dim_size[dim] = layout->dim_size[dim];
        }
      }
      else { /* one of the array allocations failed */
        *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
      }
    }
    else { /* allocating the layout struct itself failed */
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else { /* NULL source or zero dimensions */
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  }
  return dst_layout;
}
/* Compares two tensor data-layouts. Returns 0 when they match (dimension
 * count, format, datatype and every per-dimension type/size), 1 on any
 * mismatch, and 100 (with *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT) when
 * either argument is NULL. */
LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status)
{
  unsigned int result;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout_a || 0 == layout_b) {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
    result = 100; /* distinct code: comparison was not possible */
  }
  else {
    /* header fields must agree before the per-dimension arrays are inspected */
    result = (layout_a->num_dims != layout_b->num_dims
           || layout_a->format   != layout_b->format
           || layout_a->datatype != layout_b->datatype) ? 1 : 0;
    if (0 == result) {
      unsigned int i;
      for (i = 0; i < layout_a->num_dims; ++i) {
        if (layout_a->dim_type[i] != layout_b->dim_type[i]
         || layout_a->dim_size[i] != layout_b->dim_size[i])
        {
          result = 1;
        }
      }
    }
  }
  return result;
}
/* Releases a tensor data-layout and its per-dimension arrays.
 * Passing NULL yields LIBXSMM_DNN_ERR_INVALID_LAYOUT; members that were
 * never allocated are NULL (free(NULL) is a no-op). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout)
{
  if (0 == layout) {
    return LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  }
  free(layout->dim_type);
  free(layout->dim_size);
  free(layout);
  return LIBXSMM_DNN_SUCCESS;
}
/* Computes the tensor's size in bytes: element size (via
 * libxsmm_dnn_typesize) times the product of all dimension sizes.
 * Returns 0 and sets *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT for NULL. */
LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status)
{
  unsigned int nbytes = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout) {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  }
  else {
    unsigned int i;
    nbytes = (unsigned int)libxsmm_dnn_typesize(layout->datatype);
    for (i = 0; i < layout->num_dims; ++i) {
      nbytes *= layout->dim_size[i];
    }
  }
  return nbytes;
}
/* Computes the tensor's element count as the product of all dimension
 * sizes. Returns 0 and sets *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT when
 * the layout is NULL. */
LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status)
{
  unsigned int count;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == layout) {
    *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
    count = 0;
  }
  else {
    unsigned int i;
    count = 1;
    for (i = 0; i < layout->num_dims; ++i) {
      count *= layout->dim_size[i];
    }
  }
  return count;
}
/* Rebinds the tensor's data pointer without copying or taking ownership.
 * Fails with LIBXSMM_DNN_ERR_INVALID_TENSOR when tensor or data is NULL,
 * and with LIBXSMM_DNN_ERR_INVALID_LAYOUT when the tensor carries no
 * usable layout (NULL or zero dimensions). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == tensor || 0 == data) {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  else if (0 == tensor->layout || 0 == tensor->layout->num_dims) {
    status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
  }
  else { /* constness is removed, but the tensor never writes through data on its own */
    tensor->data = (void*)data;
  }
  return status;
}
/* Returns the tensor's raw data pointer, or NULL (with *status =
 * LIBXSMM_DNN_ERR_INVALID_TENSOR) when the tensor handle is NULL. */
LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
{
  void* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 != tensor) {
    result = tensor->data;
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return result;
}
/* Returns a deep copy of the tensor's layout; the caller owns the copy and
 * must release it with libxsmm_dnn_destroy_tensor_datalayout. NULL tensor
 * yields NULL and *status = LIBXSMM_DNN_ERR_INVALID_TENSOR. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor_datalayout* result = NULL;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 == tensor) {
    *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  else { /* duplication reports its own status through the same pointer */
    result = libxsmm_dnn_duplicate_tensor_datalayout(tensor->layout, status);
  }
  return result;
}
/* Returns the quantized tensor's scaling factor, or 0 (with *status =
 * LIBXSMM_DNN_ERR_INVALID_TENSOR) when the tensor handle is NULL. */
LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
{
  unsigned char result = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  if (0 != tensor) {
    result = tensor->scf;
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return result;
}
/* Stores a new scaling factor on the quantized tensor. NULL tensor yields
 * LIBXSMM_DNN_ERR_INVALID_TENSOR. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf)
{
  if (0 == tensor) {
    return LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  tensor->scf = scf;
  return LIBXSMM_DNN_SUCCESS;
}
/* Releases a tensor handle and its owned layout copy. The user's data
 * buffer is NOT freed (the tensor only links it). Destroying NULL is a
 * no-op and still returns success — consistent with free(NULL). */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor)
{
  const libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != tensor) {
    /* free the layout information stored in (and owned by) the tensor */
    if (0 != tensor->layout) {
      libxsmm_dnn_destroy_tensor_datalayout((libxsmm_dnn_tensor_datalayout*)tensor->layout);
    }
    /* deallocate the handle structure itself (constness removed) */
    free((libxsmm_dnn_tensor*)tensor);
  }
  /* releasing a NULL handle is deliberately not an error */
  return status;
}
/* Copies user data given in a common format (NCHW for activations/bias-like
 * channel tensors, KCRS for filters) into the tensor's LIBXSMM-internal
 * blocked layout; the actual copy loops live in the included templates,
 * instantiated per element type. Error codes distinguish an unsupported
 * source format (in_format) from an unsupported destination format (the
 * tensor's own layout must carry LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM).
 * NOTE(review): tensor->layout is dereferenced without a NULL check — a
 * linked tensor always has one, but confirm no caller passes a raw handle. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* @TODO check for valid combination */
  if (0 != tensor) {
    switch (tensor->layout->tensor_type) {
      /* activation-like tensors: user data arrives as NCHW */
      case LIBXSMM_DNN_REGULAR_INPUT:
      case LIBXSMM_DNN_GRADIENT_INPUT:
      case LIBXSMM_DNN_REGULAR_OUTPUT:
      case LIBXSMM_DNN_GRADIENT_OUTPUT:
      case LIBXSMM_DNN_INPUT:
      case LIBXSMM_DNN_OUTPUT:
      case LIBXSMM_DNN_ACTIVATION: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ((tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
                  /* low-precision types use the additional channel-blocking path in the template */
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef unsigned char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            }
            else { /* destination (the tensor) is not in a LIBXSMM blocked format */
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: { /* user-provided source format is not handled */
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      /* filter tensors: user data arrives as KCRS */
      case LIBXSMM_DNN_REGULAR_FILTER:
      case LIBXSMM_DNN_GRADIENT_FILTER:
      case LIBXSMM_DNN_FILTER: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
            if ((tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                /* NOTE(review): unlike copy-out, I32 filters are not accepted here — confirm intended */
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            }
            else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      /* per-channel vectors (bias, batch-norm statistics, scalars): NCHW source */
      case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
      case LIBXSMM_DNN_CHANNEL_BIAS:
      case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
      case LIBXSMM_DNN_CHANNEL_BETA:
      case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
      case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
      case LIBXSMM_DNN_CHANNEL_VARIANCE:
      case LIBXSMM_DNN_CHANNEL_SCALAR: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ((tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            }
            else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      default: { /* unknown tensor type */
        status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return status;
}
/* Fills the tensor's data buffer with zeros, element-typed according to the
 * layout's datatype. NULL tensor yields LIBXSMM_DNN_ERR_INVALID_TENSOR;
 * an unknown datatype yields LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 == tensor) {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  else {
    const size_t count = libxsmm_dnn_get_tensor_elements(tensor->layout, &status);
    size_t idx;
    /* use for-loops (rather than memset) to potentially leverage NUMA in the future */
    switch (tensor->layout->datatype) {
      case LIBXSMM_DNN_DATATYPE_F32: {
        float *const buffer = (float*)tensor->data;
        for (idx = 0; idx < count; ++idx) buffer[idx] = 0.0f;
      } break;
      case LIBXSMM_DNN_DATATYPE_BF16: {
        libxsmm_bfloat16 *const buffer = (libxsmm_bfloat16*)tensor->data;
        for (idx = 0; idx < count; ++idx) buffer[idx] = 0;
      } break;
      case LIBXSMM_DNN_DATATYPE_I32: {
        int *const buffer = (int*)tensor->data;
        for (idx = 0; idx < count; ++idx) buffer[idx] = 0;
      } break;
      case LIBXSMM_DNN_DATATYPE_I16: {
        short *const buffer = (short*)tensor->data;
        for (idx = 0; idx < count; ++idx) buffer[idx] = 0;
      } break;
      case LIBXSMM_DNN_DATATYPE_I8: {
        char *const buffer = (char*)tensor->data;
        for (idx = 0; idx < count; ++idx) buffer[idx] = 0;
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      }
    }
  }
  return status;
}
/* Copies the tensor's LIBXSMM-internal blocked data out into a user buffer
 * in a common format (NCHW for activations/channel tensors, KCRS for
 * filters); the copy loops live in the included templates, instantiated per
 * element type. Mirror of libxsmm_dnn_copyin_tensor with the SRC/DST error
 * roles swapped: here the tensor is the source, so a non-LIBXSMM tensor
 * format is an unsupported SOURCE, and an unhandled out_format an
 * unsupported DESTINATION. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* @TODO check for valid combination */
  if (0 != tensor) {
    switch (tensor->layout->tensor_type) {
      /* activation-like tensors: exported as NCHW */
      case LIBXSMM_DNN_REGULAR_INPUT:
      case LIBXSMM_DNN_GRADIENT_INPUT:
      case LIBXSMM_DNN_REGULAR_OUTPUT:
      case LIBXSMM_DNN_GRADIENT_OUTPUT:
      case LIBXSMM_DNN_INPUT:
      case LIBXSMM_DNN_OUTPUT:
      case LIBXSMM_DNN_ACTIVATION: {
        switch (out_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ((tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
                  /* low-precision types use the additional channel-blocking path in the template */
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef unsigned char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            }
            else { /* source (the tensor) is not in a LIBXSMM blocked format */
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
            }
          } break;
          default: { /* requested destination format is not handled */
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
          }
        }
      } break;
      /* filter tensors: exported as KCRS */
      case LIBXSMM_DNN_REGULAR_FILTER:
      case LIBXSMM_DNN_GRADIENT_FILTER:
      case LIBXSMM_DNN_FILTER: {
        switch (out_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
            if ((tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            }
            else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
          }
        }
      } break;
      /* per-channel vectors (bias, batch-norm statistics, scalars): NCHW destination */
      case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
      case LIBXSMM_DNN_CHANNEL_BIAS:
      case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
      case LIBXSMM_DNN_CHANNEL_BETA:
      case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
      case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
      case LIBXSMM_DNN_CHANNEL_VARIANCE:
      case LIBXSMM_DNN_CHANNEL_SCALAR: {
        switch (out_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ((tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            }
            else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
          }
        }
      } break;
      default: { /* unknown tensor type */
        status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }
  return status;
}
third_party/libxsmm/src/libxsmm_ext.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_ext.h"
#include "libxsmm_gemm.h"
#include <libxsmm.h>
#if defined(LIBXSMM_BUILD)
#if defined(LIBXSMM_BUILD_EXT) && !defined(__STATIC)
/* Weak Fortran-mangled DGEMM_BATCH entry point. When LIBXSMM's symbol
 * interception resolved a distinct original BLAS symbol, the call is routed
 * through the __wrap_ implementation (which may dispatch to LIBXSMM kernels);
 * otherwise the per-symbol BLAS error reporter is invoked and the arguments
 * are sunk — no computation takes place. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(dgemm_batch)(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  if (LIBXSMM_FSYMBOL(__real_dgemm_batch) != libxsmm_original_dgemm_batch_function) {
    /* interception active: forward to the wrap implementation */
    LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
  else {
    /* no real BLAS available: report the error; the returned sink consumes the arguments */
    libxsmm_blas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
}
/* Weak Fortran-mangled SGEMM_BATCH entry point (single precision).
 * Same routing as the DGEMM_BATCH counterpart: forward to __wrap_ when
 * interception is active, otherwise report via libxsmm_blas_error and sink
 * the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(sgemm_batch)(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  if (LIBXSMM_FSYMBOL(__real_sgemm_batch) != libxsmm_original_sgemm_batch_function) {
    /* interception active: forward to the wrap implementation */
    LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
  else {
    /* no real BLAS available: report the error; the returned sink consumes the arguments */
    libxsmm_blas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
}
/* Weak Fortran-mangled DGEMM entry point. Forwards to the __wrap_dgemm
 * implementation when LIBXSMM's interception resolved a distinct original
 * symbol; otherwise reports the missing BLAS via libxsmm_blas_error and
 * sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(dgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
  if (LIBXSMM_FSYMBOL(__real_dgemm) != libxsmm_original_dgemm_function) {
    /* interception active: forward to the wrap implementation */
    LIBXSMM_FSYMBOL(__wrap_dgemm)(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
  else {
    /* no real BLAS available: report the error; the returned sink consumes the arguments */
    libxsmm_blas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
/* Weak Fortran-mangled SGEMM entry point (single precision). Same routing
 * as DGEMM: forward to __wrap_sgemm when interception is active, otherwise
 * report via libxsmm_blas_error and sink the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(sgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
  if (LIBXSMM_FSYMBOL(__real_sgemm) != libxsmm_original_sgemm_function) {
    /* interception active: forward to the wrap implementation */
    LIBXSMM_FSYMBOL(__wrap_sgemm)(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
  else {
    /* no real BLAS available: report the error; the returned sink consumes the arguments */
    libxsmm_blas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
/* Weak Fortran-mangled DGEMV entry point. Forwards to __wrap_dgemv when
 * LIBXSMM's interception resolved a distinct original symbol; otherwise
 * reports the missing BLAS via libxsmm_blas_error and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(dgemv)(
  const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* x, const libxsmm_blasint* incx,
  const double* beta, double* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
  if (LIBXSMM_FSYMBOL(__real_dgemv) != libxsmm_original_dgemv_function) {
    /* interception active: forward to the wrap implementation */
    LIBXSMM_FSYMBOL(__wrap_dgemv)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
  else {
    /* no real BLAS available: report the error; the returned sink consumes the arguments */
    libxsmm_blas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
}
/* Weak Fortran-mangled SGEMV entry point (single precision). Same routing
 * as DGEMV: forward to __wrap_sgemv when interception is active, otherwise
 * report via libxsmm_blas_error and sink the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(sgemv)(
  const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* x, const libxsmm_blasint* incx,
  const float* beta, float* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
  if (LIBXSMM_FSYMBOL(__real_sgemv) != libxsmm_original_sgemv_function) {
    /* interception active: forward to the wrap implementation */
    LIBXSMM_FSYMBOL(__wrap_sgemv)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
  else {
    /* no real BLAS available: report the error; the returned sink consumes the arguments */
    libxsmm_blas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
}
/* C-named (non-mangled) dgemm_batch: thin alias that forwards to the
 * Fortran-mangled weak symbol above, which performs the actual routing. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void dgemm_batch(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
/* C-named (non-mangled) sgemm_batch: thin alias that forwards to the
 * Fortran-mangled weak symbol above, which performs the actual routing. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_WEAK void sgemm_batch(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
#elif (0 != LIBXSMM_NO_BLAS)
/* no-BLAS library */
/* Public common-linkage definitions of the four 16-element RNG state arrays that the
 * no-BLAS build must still provide (the names suggest 512-bit-wide intrinsics state;
 * consumers live elsewhere in the library — NOTE(review): confirm against libxsmm_main). */
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state0[16]);
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state1[16]);
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state2[16]);
LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_ATTRIBUTE_COMMON unsigned int libxsmm_intrinsics_mm512_rng_state3[16]);
/* Variadic no-op: stands in for any unavailable BLAS routine in the no-BLAS build.
 * Accepts and ignores every argument so call sites keep their original signature. */
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE void internal_noblas_sink(LIBXSMM_VARIADIC);
LIBXSMM_API_INTERN void internal_noblas_sink(LIBXSMM_VARIADIC)
{
  /* does nothing else but sinking given arguments */
}
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_NO_TRACE libxsmm_sink_function internal_noblas_error(const char* /*symbol*/);
/* Reports a call into a missing BLAS routine (via LIBXSMM_BLAS_ERROR, which receives
 * the symbol name and a per-process error counter) and returns the variadic sink so
 * the call site can still "invoke" the routine with its original arguments. */
LIBXSMM_API_INTERN libxsmm_sink_function internal_noblas_error(const char* symbol)
{
  static int internal_noblas_nerror = 0; /* counter shared across all calls */
  LIBXSMM_BLAS_ERROR(symbol, &internal_noblas_nerror);
  return internal_noblas_sink;
}
/* No-BLAS stub for the Fortran-style batched DGEMM symbol: reports the missing BLAS
 * dependency and sinks all arguments; performs no computation and does not touch c_array. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ void LIBXSMM_FSYMBOL(dgemm_batch)(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  /* internal_noblas_error returns the variadic sink, which then consumes the arguments */
  internal_noblas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
/* No-BLAS stub for the Fortran-style batched SGEMM symbol: reports the missing BLAS
 * dependency and sinks all arguments; performs no computation and does not touch c_array. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ void LIBXSMM_FSYMBOL(sgemm_batch)(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  /* internal_noblas_error returns the variadic sink, which then consumes the arguments */
  internal_noblas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
/* No-BLAS stub for the Fortran-style DGEMM symbol: reports the missing BLAS
 * dependency and sinks all arguments; performs no computation and does not touch c. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ void LIBXSMM_FSYMBOL(dgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
  /* error is reported once via the shared counter; arguments are swallowed by the sink */
  internal_noblas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* No-BLAS stub for the Fortran-style SGEMM symbol: reports the missing BLAS
 * dependency and sinks all arguments; performs no computation and does not touch c. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ void LIBXSMM_FSYMBOL(sgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc) LIBXSMM_BLAS_NOEXCEPT(gemm)
{
  /* error is reported once via the shared counter; arguments are swallowed by the sink */
  internal_noblas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* No-BLAS stub for the Fortran-style DGEMV symbol: reports the missing BLAS
 * dependency and sinks all arguments; performs no computation and does not touch y. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ void LIBXSMM_FSYMBOL(dgemv)(
  const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* x, const libxsmm_blasint* incx,
  const double* beta, double* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
  /* error is reported once via the shared counter; arguments are swallowed by the sink */
  internal_noblas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
/* No-BLAS stub for the Fortran-style SGEMV symbol: reports the missing BLAS
 * dependency and sinks all arguments; performs no computation and does not touch y. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE /*LIBXSMM_ATTRIBUTE_WEAK*/ void LIBXSMM_FSYMBOL(sgemv)(
  const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* x, const libxsmm_blasint* incx,
  const float* beta, float* y, const libxsmm_blasint* incy) LIBXSMM_BLAS_NOEXCEPT(gemv)
{
  /* error is reported once via the shared counter; arguments are swallowed by the sink */
  internal_noblas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
/* No-BLAS C-ABI entry point for batched DGEMM: forwards to the Fortran-style no-BLAS
 * stub above, which reports the missing dependency and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE void dgemm_batch(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  /* pure delegation to the decorated symbol */
  LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
/* No-BLAS C-ABI entry point for batched SGEMM: forwards to the Fortran-style no-BLAS
 * stub above, which reports the missing dependency and sinks the arguments. */
LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_ATTRIBUTE_NO_TRACE void sgemm_batch(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[]) LIBXSMM_BLAS_NOEXCEPT(gemm_batch)
{
  /* pure delegation to the decorated symbol */
  LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
    group_count, group_size);
}
#endif
#endif
/*defined(LIBXSMM_BUILD)*/
third_party/libxsmm/src/libxsmm_ext.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_EXT_H
#define LIBXSMM_EXT_H
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#if defined(_OPENMP)
# if !defined(__INTEL_COMPILER)
# if defined(__clang__)
# pragma clang diagnostic push
# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)
# pragma GCC diagnostic push
# endif
# if defined(__clang__)
# pragma clang diagnostic ignored "-Wpedantic"
# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)
# pragma GCC diagnostic ignored "-Wpedantic"
# endif
# endif
# include <omp.h>
# if defined(LIBXSMM_TRACE_CALLERID_GCCBUILTIN) && !defined(__INTEL_COMPILER)
# if defined(__clang__)
# pragma clang diagnostic pop
# elif defined(__GNUC__) && LIBXSMM_VERSION2(4, 6) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)
# pragma GCC diagnostic pop
# endif
# endif
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#endif
/*LIBXSMM_EXT_H*/
third_party/libxsmm/src/libxsmm_ext_gemm.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm.h>
#include "libxsmm_gemm.h"
#include "libxsmm_ext.h"
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# include "libxsmm_trace.h"
#endif
#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO) && 0
# define LIBXSMM_EXT_GEMM_PARGROUPS_INFO
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# if !defined(LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH)
# define LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO)
# endif
# if !defined(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH)
# define LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH 8
/*POT*/
# endif
/* Translation-unit state backing the GEMM wrap/recording feature:
 * - internal_ext_gemm_batchdesc: stack of recording descriptors (depth bounded by
 *   LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH, a power of two per the comment above);
 * - internal_ext_gemm_batchdepth: current nesting depth into that stack;
 * - internal_ext_gemm_batchsize: running count of recorded multiplications
 *   (updated atomically by the __wrap_* entry points below). */
LIBXSMM_APIVAR_DEFINE(libxsmm_gemm_descriptor internal_ext_gemm_batchdesc[LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH]);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_ext_gemm_batchdepth);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_ext_gemm_batchsize);
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
/* qsort comparator ordering libxsmm_mmbatch_item records by descending stat.count
 * (higher call-count first); equal counts compare as equivalent. */
LIBXSMM_API_INLINE int internal_mmbatch_sortrev(const void* stat_a, const void* stat_b)
{
  const libxsmm_mmbatch_item *const item_a = (const libxsmm_mmbatch_item*)stat_a;
  const libxsmm_mmbatch_item *const item_b = (const libxsmm_mmbatch_item*)stat_b;
  int order = 0; /* default: treat as equal */
  LIBXSMM_ASSERT(NULL != stat_a && NULL != stat_b);
  if (item_a->stat.count < item_b->stat.count) {
    order = 1; /* a sorts after b (descending order) */
  }
  else if (item_b->stat.count < item_a->stat.count) {
    order = -1; /* a sorts before b */
  }
  return order;
}
#endif
/*defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)*/
/* Drains a recorded batch of multiplications.
 * Two modes, selected by the descriptor's LIBXSMM_MMBATCH_FLAG_STATISTIC bit:
 * - process mode: dispatches a kernel for the descriptor and executes all recorded
 *   (a,b,c) triplets, either sequentially, via an OpenMP parallel region, or via
 *   OpenMP tasks (LIBXSMM_EXT_TASKS), then zeroes the array;
 * - statistic mode: sorts records by call count and prints a summary to stderr.
 * Returns EXIT_SUCCESS, or EXIT_FAILURE when no kernel could be dispatched
 * (or whatever the sequential libxsmm_mmbatch_kernel call returns).
 * NOTE(review): batcharray is assumed to hold at least batchsize items — confirm at call sites. */
LIBXSMM_API_INLINE int internal_mmbatch_flush(const libxsmm_gemm_descriptor* batchdesc,
  libxsmm_blasint batchsize, libxsmm_mmbatch_item* batcharray)
{
  int result = EXIT_SUCCESS;
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
  if (0 != batchsize) { /* recorded/lazy multiplications */
    const libxsmm_blasint itemsize = sizeof(libxsmm_mmbatch_item);
    LIBXSMM_ASSERT(NULL != batchdesc && 0 < batchsize);
    if (0 == (LIBXSMM_MMBATCH_FLAG_STATISTIC & batchdesc->flags)) { /* process batch */
      const libxsmm_xmmfunction kernel = libxsmm_xmmdispatch(batchdesc);
      if (NULL != kernel.xmm) {
        /* element sizes of input and output datatypes encoded in the descriptor */
        const unsigned char itypesize = libxsmm_typesize((libxsmm_datatype)LIBXSMM_GETENUM_INP(batchdesc->datatype));
        const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)LIBXSMM_GETENUM_OUT(batchdesc->datatype));
#if defined(_OPENMP)
        if (0 == (LIBXSMM_MMBATCH_FLAG_SEQUENTIAL & batchdesc->flags)) { /* parallelized */
          const int nchunks = (int)LIBXSMM_UPDIV(batchsize, libxsmm_gemm_taskgrain);
# if defined(LIBXSMM_EXT_TASKS)
          if (0 == omp_get_active_level()) {
            const int max_nthreads = omp_get_max_threads();
            const int nthreads = LIBXSMM_MIN(max_nthreads, nchunks);
            if (0 == libxsmm_gemm_tasks)
# else
          if (0 == omp_in_parallel()) {
            const int max_nthreads = omp_get_max_threads();
            const int nthreads = LIBXSMM_MIN(max_nthreads, nchunks);
# endif
            { /* classic internal parallelization */
#             pragma omp parallel num_threads(nthreads)
              /*check*/libxsmm_mmbatch_kernel(
                kernel, 0/*index_base*/, 0/*index_stride*/,
                &itemsize, &itemsize, &itemsize, /* a/b/c strides equal the item size */
                &batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
                /* negative batchsize requests the synchronized variant */
                0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize,
                omp_get_thread_num(), nthreads, itypesize, otypesize, batchdesc->flags);
            }
# if defined(LIBXSMM_EXT_TASKS)
            else { /* internal parallelization with tasks */
#             pragma omp parallel num_threads(nthreads)
              { /* first thread discovering work will launch all tasks */
#               pragma omp single nowait /* anyone is good */
                {
                  int tid;
                  for (tid = 0; tid < nchunks/*ntasks*/; ++tid) {
#                   pragma omp task untied
                    /*check*/libxsmm_mmbatch_kernel(
                      kernel, 0/*index_base*/, 0/*index_stride*/,
                      &itemsize, &itemsize, &itemsize,
                      &batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
                      0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize,
                      tid, nchunks/*ntasks*/, itypesize, otypesize, batchdesc->flags);
                  }
                }
              } /* implicit synchronization (barrier) */
            }
# endif
          }
          else { /* assume external parallelization */
            int tid;
            for (tid = 0; tid < nchunks/*ntasks*/; ++tid) {
# if defined(LIBXSMM_EXT_TASKS)
#             pragma omp task untied
#endif
              /*check*/libxsmm_mmbatch_kernel(
                kernel, 0/*index_base*/, 0/*index_stride*/,
                &itemsize, &itemsize, &itemsize,
                &batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
                0 == (LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED & batchdesc->flags) ? batchsize : -batchsize,
                tid, nchunks/*ntasks*/, itypesize, otypesize, batchdesc->flags);
            }
# if defined(LIBXSMM_EXT_TASKS)
            if (0 == libxsmm_nosync) { /* allow to omit synchronization */
#             pragma omp taskwait
            }
# endif
          }
        }
        else
#endif
        { /* sequential */
          result = libxsmm_mmbatch_kernel(
            kernel, 0/*index_base*/, 0/*index_stride*/,
            &itemsize, &itemsize, &itemsize,
            &batcharray->value.a, &batcharray->value.b, &batcharray->value.c,
            batchsize, 0/*tid*/, 1/*nthreads*/, itypesize, otypesize, batchdesc->flags);
        }
      }
      else { /* no fallback */
        /* several reasons to arrive here: try-lock, unsuitable SMM, etc. */
        result = EXIT_FAILURE;
      }
      memset(batcharray, 0, (size_t)batchsize * (size_t)itemsize); /* clear */
    }
    else { /* print statistic */
      /* limit the number of printed entries unless verbosity is raised */
      const libxsmm_blasint limit = (LIBXSMM_GEMM_MMBATCH_VERBOSITY < libxsmm_verbosity ? batchsize/*unlimited*/ : 7/*limited*/);
      unsigned int threshold, batchcount;
      libxsmm_blasint count = 0, i;
      LIBXSMM_ASSERT(NULL != batcharray);
      /* order by descending call count so the hottest multiplications print first */
      qsort(batcharray, (size_t)batchsize, (size_t)itemsize, internal_mmbatch_sortrev);
      batchcount = batcharray[0].stat.count;
      /* only entries above half of the hottest count are shown (unless verbose/small) */
      threshold = ((LIBXSMM_GEMM_MMBATCH_VERBOSITY < libxsmm_verbosity || 3 >= batchsize) ? 0 : (batchcount / 2));
      for (i = 1; i < batchsize; ++i) batchcount += batcharray[i].stat.count;
      LIBXSMM_STDIO_ACQUIRE();
      for (i = 0; i < batchsize; ++i) {
        const libxsmm_gemm_descriptor descriptor = batcharray[i].stat.desc;
        const libxsmm_blasint lda = descriptor.lda, ldb = descriptor.ldb, ldc = descriptor.ldc;
        const libxsmm_blasint m = descriptor.m, n = descriptor.n, k = descriptor.k;
        const char *const symbol = batcharray[i].stat.symbol;
        const unsigned int ci = batcharray[i].stat.count;
        LIBXSMM_MEMZERO127(batcharray + i); /* clear */
        if (threshold < ci && count < limit /* limit printed statistic */
          && 0 < m && 0 < n && 0 < k)
        {
          /* percentage of the total call count, rounded to nearest */
          const unsigned int ciperc = (unsigned int)(100.0 * ci / batchcount + 0.5);
          if (0 != ciperc) {
            LIBXSMM_ASSERT(0 != ci);
            if (0 == count) { /* header printed once, before the first entry */
              fprintf(stderr, "\nLIBXSMM STATISTIC: %u multiplication%c\n",
                batchcount, 1 < batchcount ? 's' : ' ');
            }
            LIBXSMM_GEMM_PRINT2(stderr,
              LIBXSMM_GETENUM_INP(descriptor.datatype), LIBXSMM_GETENUM_OUT(descriptor.datatype),
              descriptor.flags, m, n, k,
              /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & descriptor.flags) ? 0 : */1,
              NULL/*a*/, lda, NULL/*b*/, ldb,
              0 != (LIBXSMM_GEMM_FLAG_BETA_0 & descriptor.flags) ? 0 : 1,
              NULL/*c*/, ldc);
            if (NULL != symbol && 0 != *symbol) { /* caller symbol captured at record time */
              fprintf(stderr, ": %u%% [%s]\n", ciperc, symbol);
            }
            else {
              fprintf(stderr, ": %u%%\n", ciperc);
            }
            ++count;
          }
          else break; /* sorted order: remaining entries round to 0%% */
        }
      }
      LIBXSMM_STDIO_RELEASE();
    }
  }
#else
  LIBXSMM_UNUSED(batchdesc); LIBXSMM_UNUSED(batchsize); LIBXSMM_UNUSED(batcharray);
#endif
  return result;
}
#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT)
#if defined(LIBXSMM_BLAS_WRAP_DYNAMIC)
/* Returns the pointer to the original (wrapped) dgemm_batch implementation.
 * LIBXSMM_BLAS_WRAPPER populates libxsmm_original_dgemm_batch_function; its first
 * argument reflects whether a real BLAS is available (LIBXSMM_BLAS). */
LIBXSMM_API libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void)
{
# if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, double, gemm_batch, libxsmm_original_dgemm_batch_function, libxsmm_original_dgemm_batch/*self*/);
  /*LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_batch_function);*/
# else
  LIBXSMM_BLAS_WRAPPER(0, double, gemm_batch, libxsmm_original_dgemm_batch_function, libxsmm_original_dgemm_batch/*self*/);
# endif
  return libxsmm_original_dgemm_batch_function;
}
/* Returns the pointer to the original (wrapped) sgemm_batch implementation;
 * single-precision counterpart of libxsmm_original_dgemm_batch. */
LIBXSMM_API libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void)
{
# if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, float, gemm_batch, libxsmm_original_sgemm_batch_function, libxsmm_original_sgemm_batch/*self*/);
  /*LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_batch_function);*/
# else
  LIBXSMM_BLAS_WRAPPER(0, float, gemm_batch, libxsmm_original_sgemm_batch_function, libxsmm_original_sgemm_batch/*self*/);
# endif
  return libxsmm_original_sgemm_batch_function;
}
/* Returns the pointer to the original (wrapped) dgemm implementation; asserts the
 * pointer is populated when a real BLAS is available. */
LIBXSMM_API libxsmm_dgemm_function libxsmm_original_dgemm(void)
{
# if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, double, gemm, libxsmm_original_dgemm_function, libxsmm_original_dgemm/*self*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_function);
# else
  LIBXSMM_BLAS_WRAPPER(0, double, gemm, libxsmm_original_dgemm_function, libxsmm_original_dgemm/*self*/);
# endif
  return libxsmm_original_dgemm_function;
}
/* Returns the pointer to the original (wrapped) sgemm implementation; asserts the
 * pointer is populated when a real BLAS is available. */
LIBXSMM_API libxsmm_sgemm_function libxsmm_original_sgemm(void)
{
# if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, float, gemm, libxsmm_original_sgemm_function, libxsmm_original_sgemm/*self*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_function);
# else
  LIBXSMM_BLAS_WRAPPER(0, float, gemm, libxsmm_original_sgemm_function, libxsmm_original_sgemm/*self*/);
# endif
  return libxsmm_original_sgemm_function;
}
/* Returns the pointer to the original (wrapped) dgemv implementation; asserts the
 * pointer is populated when a real BLAS is available. */
LIBXSMM_API libxsmm_dgemv_function libxsmm_original_dgemv(void)
{
# if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, double, gemv, libxsmm_original_dgemv_function, libxsmm_original_dgemv/*self*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_dgemv_function);
# else
  LIBXSMM_BLAS_WRAPPER(0, double, gemv, libxsmm_original_dgemv_function, libxsmm_original_dgemv/*self*/);
# endif
  return libxsmm_original_dgemv_function;
}
/* Returns the pointer to the original (wrapped) sgemv implementation; asserts the
 * pointer is populated when a real BLAS is available. */
LIBXSMM_API libxsmm_sgemv_function libxsmm_original_sgemv(void)
{
# if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, float, gemv, libxsmm_original_sgemv_function, libxsmm_original_sgemv/*self*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_sgemv_function);
# else
  LIBXSMM_BLAS_WRAPPER(0, float, gemv, libxsmm_original_sgemv_function, libxsmm_original_sgemv/*self*/);
# endif
  return libxsmm_original_sgemv_function;
}
#endif
/*defined(LIBXSMM_BLAS_WRAP_DYNAMIC)*/
/* Link-time interposer for dgemm_batch (GNU ld --wrap style symbol).
 * Dispatch, driven by libxsmm_gemm_wrap:
 *  - bit 0 set: sequential LIBXSMM path (libxsmm_dgemm_batch);
 *  - nonzero with bit 0 clear: OpenMP-parallelized path (libxsmm_dgemm_batch_omp);
 *  - zero: fall through to the original BLAS symbol (LIBXSMM_GEMM_BATCH_SYMBOL). */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_ASSERT(NULL != lda_array && NULL != ldb_array && NULL != ldc_array && NULL != m_array && NULL != n_array && NULL != k_array);
  LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != alpha_array && NULL != beta_array);
  LIBXSMM_ASSERT(NULL != group_count && NULL != group_size);
  LIBXSMM_INIT
  if (0 != libxsmm_gemm_wrap) {
    if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
      libxsmm_dgemm_batch(transa_array, transb_array, m_array, n_array, k_array,
        alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
        group_count, group_size);
    }
    else { /* parallelized */
      libxsmm_dgemm_batch_omp(transa_array, transb_array, m_array, n_array, k_array,
        alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
        group_count, group_size);
    }
  }
  else { /* wrapping disabled: call through to the original BLAS routine */
    LIBXSMM_GEMM_BATCH_SYMBOL(double)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
}
/* Link-time interposer for sgemm_batch; single-precision counterpart of
 * __wrap_dgemm_batch with identical dispatch on libxsmm_gemm_wrap. */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_ASSERT(NULL != lda_array && NULL != ldb_array && NULL != ldc_array && NULL != m_array && NULL != n_array && NULL != k_array);
  LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != alpha_array && NULL != beta_array);
  LIBXSMM_ASSERT(NULL != group_count && NULL != group_size);
  LIBXSMM_INIT
  if (0 != libxsmm_gemm_wrap) {
    if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
      libxsmm_sgemm_batch(transa_array, transb_array, m_array, n_array, k_array,
        alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
        group_count, group_size);
    }
    else { /* parallelized */
      libxsmm_sgemm_batch_omp(transa_array, transb_array, m_array, n_array, k_array,
        alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
        group_count, group_size);
    }
  }
  else { /* wrapping disabled: call through to the original BLAS routine */
    LIBXSMM_GEMM_BATCH_SYMBOL(float)(transa_array, transb_array, m_array, n_array, k_array,
      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
      group_count, group_size);
  }
}
/* Link-time interposer for dgemm.
 * Behavior (when LIBXSMM_WRAP recording is compiled in):
 *  - If the call does NOT match the currently recorded batch descriptor (different
 *    shape/leading dims/flags/alpha/beta, or no recording active), the GEMM is
 *    executed immediately: sequential libxsmm_dgemm or libxsmm_dgemm_omp per
 *    libxsmm_gemm_wrap; in statistic mode the call is additionally counted.
 *  - If the call matches the recorded descriptor, only the (a,b,c) pointers are
 *    appended to the recording ring buffer for deferred execution.
 *  - A matching slot index that hits libxsmm_mmbatch_size triggers a one-time
 *    internal_mmbatch_flush of the recorded batch.
 * Under _DEBUG and LIBXSMM_GEMM_CHECK, the result is verified against the original
 * BLAS dgemm and the relative error reported on stderr. */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemm)(
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  LIBXSMM_ASSERT(NULL != lda && NULL != ldb && NULL != ldc && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != transa && NULL != transb && NULL != alpha && NULL != beta);
  {
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
    unsigned int i = 0; /* no flush */
    int flags = -1; /* set as a side effect of the descriptor comparison below */
# if !defined(NDEBUG)
    static int error_once = 0;
    int result = EXIT_SUCCESS;
# endif
    LIBXSMM_INIT
    /* execute immediately unless this call matches the active recording descriptor */
    if (0 != libxsmm_gemm_wrap && (NULL == libxsmm_mmbatch_array
      || LIBXSMM_GEMM_PRECISION_F64 != libxsmm_mmbatch_desc.datatype
      || ((unsigned int)*lda) != libxsmm_mmbatch_desc.lda
      || ((unsigned int)*ldb) != libxsmm_mmbatch_desc.ldb
      || ((unsigned int)*ldc) != libxsmm_mmbatch_desc.ldc
      || ((unsigned int)*m) != libxsmm_mmbatch_desc.m
      || ((unsigned int)*n) != libxsmm_mmbatch_desc.n
      || ((unsigned int)*k) != libxsmm_mmbatch_desc.k
      || (flags = LIBXSMM_GEMM_FLAGS(*transa, *transb)) != (int)(LIBXSMM_GEMM_FLAG_TRANS_AB & libxsmm_mmbatch_desc.flags)
      || LIBXSMM_NEQ(/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, *alpha)
      || LIBXSMM_NEQ(0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, *beta)))
#endif
    {
#if defined(_DEBUG)
      const char *const env_check = getenv("LIBXSMM_GEMM_CHECK");
      const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
      void* d = NULL;
      if (LIBXSMM_NEQ(0, check)) { /* keep a copy of C to compare against reference BLAS */
        const size_t size = (size_t)(*ldc) * (size_t)(*n) * sizeof(double);
        d = libxsmm_scratch_malloc(size, 0/*auto*/, LIBXSMM_MALLOC_INTERNAL_CALLER);
        if (NULL != d && LIBXSMM_NEQ(0, *beta)) memcpy(d, c, size); /* copy destination */
      }
#endif
      if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
        libxsmm_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
      }
      else { /* parallelized */
        libxsmm_dgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
      }
#if defined(_DEBUG)
      if (NULL != d) { /* compute reference result and report relative error if too large */
        libxsmm_matdiff_info diff;
        libxsmm_blas_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, d, ldc);
        if (EXIT_SUCCESS == libxsmm_matdiff(&diff, LIBXSMM_DATATYPE_F64, *m, *n, d, c, ldc, ldc)
          && check < 100.0 * diff.normf_rel)
        {
          LIBXSMM_STDIO_ACQUIRE();
          fprintf(stderr, "LIBXSMM: ");
          libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION_F64,
            transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
          fprintf(stderr, " => %f%% ERROR\n", 100.0 * diff.normf_rel);
          LIBXSMM_STDIO_RELEASE();
        }
        libxsmm_free(d);
      }
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
      if (0 != (LIBXSMM_MMBATCH_FLAG_STATISTIC & libxsmm_mmbatch_desc.flags)) {
        /* statistic mode: record this call's descriptor (or bump its count) */
        libxsmm_descriptor_blob blob;
        const libxsmm_gemm_descriptor *const descriptor = libxsmm_dgemm_descriptor_init(&blob,
          *m, *n, *k, *lda, *ldb, *ldc, *alpha, *beta,
          LIBXSMM_GEMM_FLAGS(*transa, *transb), LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH);
        LIBXSMM_ASSERT(0 != libxsmm_mmbatch_size);
        if (NULL != descriptor) {
          const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
          const unsigned int batchsize = LIBXSMM_ATOMIC_LOAD(&internal_ext_gemm_batchsize, LIBXSMM_ATOMIC_RELAXED);
          /* ring-buffer arithmetic: map the running counter into [1, max_batchsize] */
          const unsigned int max_size = (0 != batchsize ? (((batchsize - 1) % max_batchsize) + 1) : 0);
          libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array;
          libxsmm_mmbatch_item* batcharray_cur = batcharray;
          unsigned int size = max_size;
          if (libxsmm_mmbatch_size < max_size) {
            size = max_size - libxsmm_mmbatch_size;
            batcharray_cur += libxsmm_mmbatch_size;
          }
          /* linear search for an already-recorded identical descriptor */
          i = libxsmm_diff_n(descriptor, batcharray_cur, sizeof(libxsmm_gemm_descriptor),
            sizeof(libxsmm_mmbatch_item)/*stride*/, 0/*hint*/, size);
          if (i < size) { /* update existing entry */
            LIBXSMM_ATOMIC_ADD_FETCH(&batcharray_cur[i].stat.count, 1, LIBXSMM_ATOMIC_RELAXED);
          }
          else { /* new entry needed */
            const int all = -1, shift = 0;
            void* extra = 0;
            i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1;
            batcharray[i - 1].stat.desc = *descriptor;
            batcharray[i - 1].stat.count = 1;
            /* capture the caller symbol for the eventual statistic printout */
            batcharray[i - 1].stat.symbol = libxsmm_trace_info(NULL/*depth*/, NULL/*tid*/, &all, LIBXSMM_FUNCNAME, &shift, &all);
            if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(libxsmm_mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra)) {
              /* register the flush hook in the buffer's extra storage */
              *(libxsmm_mmbatch_flush_function*)extra = libxsmm_mmbatch_end;
            }
# if !defined(NDEBUG)
            else {
              result = EXIT_FAILURE;
            }
# endif
          }
        }
      }
#endif
    }
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
    else { /* recording mode: defer execution, append only the operand pointers */
      libxsmm_mmbatch_item *const batcharray = (libxsmm_mmbatch_item*)libxsmm_mmbatch_array;
      const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
      i = ((LIBXSMM_ATOMIC_ADD_FETCH(&internal_ext_gemm_batchsize, 1, LIBXSMM_ATOMIC_RELAXED) - 1) % max_batchsize) + 1;
      batcharray[i - 1].value.a = a;
      batcharray[i - 1].value.b = b;
      batcharray[i - 1].value.c = c;
      LIBXSMM_ASSERT(0 <= flags);
    }
    if (libxsmm_mmbatch_size == (i - 1)) { /* condition ensure to flush once (first discovery) */
# if !defined(NDEBUG)
      result =
# endif
      internal_mmbatch_flush(&libxsmm_mmbatch_desc, libxsmm_mmbatch_size,
        (libxsmm_mmbatch_item*)libxsmm_mmbatch_array);
    }
# if !defined(NDEBUG) /* library code is expected to be mute */
    if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: DGEMM batch recording failed!\n");
    }
# endif
#endif
  }
}
LIBXSMM_APIEXT
LIBXSMM_ATTRIBUTE_USED
void
LIBXSMM_FSYMBOL
(
__wrap_sgemm
)(
const
char
*
transa
,
const
char
*
transb
,
const
libxsmm_blasint
*
m
,
const
libxsmm_blasint
*
n
,
const
libxsmm_blasint
*
k
,
const
float
*
alpha
,
const
float
*
a
,
const
libxsmm_blasint
*
lda
,
const
float
*
b
,
const
libxsmm_blasint
*
ldb
,
const
float
*
beta
,
float
*
c
,
const
libxsmm_blasint
*
ldc
)
{
LIBXSMM_ASSERT
(
NULL
!=
lda
&&
NULL
!=
ldb
&&
NULL
!=
ldc
&&
NULL
!=
m
&&
NULL
!=
n
&&
NULL
!=
k
);
LIBXSMM_ASSERT
(
NULL
!=
transa
&&
NULL
!=
transb
&&
NULL
!=
alpha
&&
NULL
!=
beta
);
{
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
unsigned
int
i
=
0
;
/* no flush */
int
flags
=
-
1
;
# if !defined(NDEBUG)
static
int
error_once
=
0
;
int
result
=
EXIT_SUCCESS
;
# endif
LIBXSMM_INIT
if
(
0
!=
libxsmm_gemm_wrap
&&
(
NULL
==
libxsmm_mmbatch_array
||
LIBXSMM_GEMM_PRECISION_F32
!=
libxsmm_mmbatch_desc
.
datatype
||
((
unsigned
int
)
*
lda
)
!=
libxsmm_mmbatch_desc
.
lda
||
((
unsigned
int
)
*
ldb
)
!=
libxsmm_mmbatch_desc
.
ldb
||
((
unsigned
int
)
*
ldc
)
!=
libxsmm_mmbatch_desc
.
ldc
||
((
unsigned
int
)
*
m
)
!=
libxsmm_mmbatch_desc
.
m
||
((
unsigned
int
)
*
n
)
!=
libxsmm_mmbatch_desc
.
n
||
((
unsigned
int
)
*
k
)
!=
libxsmm_mmbatch_desc
.
k
||
(
flags
=
LIBXSMM_GEMM_FLAGS
(
*
transa
,
*
transb
))
!=
(
int
)(
LIBXSMM_GEMM_FLAG_TRANS_AB
&
libxsmm_mmbatch_desc
.
flags
)
||
LIBXSMM_NEQ
(
/*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */
1
,
*
alpha
)
||
LIBXSMM_NEQ
(
0
!=
(
LIBXSMM_GEMM_FLAG_BETA_0
&
libxsmm_mmbatch_desc
.
flags
)
?
0
:
1
,
*
beta
)))
#endif
{
#if defined(_DEBUG)
const
char
*
const
env_check
=
getenv
(
"LIBXSMM_GEMM_CHECK"
);
const
double
check
=
LIBXSMM_ABS
(
NULL
==
env_check
?
0
:
atof
(
env_check
));
void
*
d
=
NULL
;
if
(
LIBXSMM_NEQ
(
0
,
check
))
{
const
size_t
size
=
(
size_t
)(
*
ldc
)
*
(
size_t
)(
*
n
)
*
sizeof
(
float
);
d
=
libxsmm_scratch_malloc
(
size
,
0
/*auto*/
,
LIBXSMM_MALLOC_INTERNAL_CALLER
);
if
(
NULL
!=
d
&&
LIBXSMM_NEQ
(
0
,
*
beta
))
memcpy
(
d
,
c
,
size
);
/* copy destination */
}
#endif
if
(
0
!=
(
libxsmm_gemm_wrap
&
1
))
{
/* sequential */
libxsmm_sgemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
}
else
{
/* parallelized */
libxsmm_sgemm_omp
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
}
#if defined(_DEBUG)
if
(
NULL
!=
d
)
{
libxsmm_matdiff_info
diff
;
libxsmm_blas_sgemm
(
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
d
,
ldc
);
if
(
EXIT_SUCCESS
==
libxsmm_matdiff
(
&
diff
,
LIBXSMM_DATATYPE_F32
,
*
m
,
*
n
,
d
,
c
,
ldc
,
ldc
)
&&
check
<
100
.
0
*
diff
.
normf_rel
)
{
LIBXSMM_STDIO_ACQUIRE
();
fprintf
(
stderr
,
"LIBXSMM: "
);
libxsmm_gemm_print
(
stderr
,
LIBXSMM_GEMM_PRECISION_F32
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
fprintf
(
stderr
,
" => %f%% ERROR
\n
"
,
100
.
0
*
diff
.
normf_rel
);
LIBXSMM_STDIO_RELEASE
();
}
libxsmm_free
(
d
);
}
#endif
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
if
(
0
!=
(
LIBXSMM_MMBATCH_FLAG_STATISTIC
&
libxsmm_mmbatch_desc
.
flags
))
{
libxsmm_descriptor_blob
blob
;
const
libxsmm_gemm_descriptor
*
const
descriptor
=
libxsmm_sgemm_descriptor_init
(
&
blob
,
*
m
,
*
n
,
*
k
,
*
lda
,
*
ldb
,
*
ldc
,
*
alpha
,
*
beta
,
LIBXSMM_GEMM_FLAGS
(
*
transa
,
*
transb
),
LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH
);
LIBXSMM_ASSERT
(
0
!=
libxsmm_mmbatch_size
);
if
(
NULL
!=
descriptor
)
{
const
unsigned
int
max_batchsize
=
(
unsigned
int
)((
LIBXSMM_GEMM_MMBATCH_SCALE
)
*
libxsmm_mmbatch_size
);
const
unsigned
int
batchsize
=
LIBXSMM_ATOMIC_LOAD
(
&
internal_ext_gemm_batchsize
,
LIBXSMM_ATOMIC_RELAXED
);
const
unsigned
int
max_size
=
(
0
!=
batchsize
?
(((
batchsize
-
1
)
%
max_batchsize
)
+
1
)
:
0
);
libxsmm_mmbatch_item
*
const
batcharray
=
(
libxsmm_mmbatch_item
*
)
libxsmm_mmbatch_array
;
libxsmm_mmbatch_item
*
batcharray_cur
=
batcharray
;
unsigned
int
size
=
max_size
;
if
(
libxsmm_mmbatch_size
<
max_size
)
{
size
=
max_size
-
libxsmm_mmbatch_size
;
batcharray_cur
+=
libxsmm_mmbatch_size
;
}
i
=
libxsmm_diff_n
(
descriptor
,
batcharray_cur
,
sizeof
(
libxsmm_gemm_descriptor
),
sizeof
(
libxsmm_mmbatch_item
)
/*stride*/
,
0
/*hint*/
,
size
);
if
(
i
<
size
)
{
/* update existing entry */
LIBXSMM_ATOMIC_ADD_FETCH
(
&
batcharray_cur
[
i
].
stat
.
count
,
1
,
LIBXSMM_ATOMIC_RELAXED
);
}
else
{
/* new entry needed */
const
int
all
=
-
1
,
shift
=
0
;
void
*
extra
=
0
;
i
=
((
LIBXSMM_ATOMIC_ADD_FETCH
(
&
internal_ext_gemm_batchsize
,
1
,
LIBXSMM_ATOMIC_RELAXED
)
-
1
)
%
max_batchsize
)
+
1
;
batcharray
[
i
-
1
].
stat
.
desc
=
*
descriptor
;
batcharray
[
i
-
1
].
stat
.
count
=
1
;
batcharray
[
i
-
1
].
stat
.
symbol
=
libxsmm_trace_info
(
NULL
/*depth*/
,
NULL
/*tid*/
,
&
all
,
LIBXSMM_FUNCNAME
,
&
shift
,
&
all
);
if
(
EXIT_SUCCESS
==
libxsmm_get_malloc_xinfo
(
libxsmm_mmbatch_array
,
NULL
/*size*/
,
NULL
/*flags*/
,
&
extra
))
{
*
(
libxsmm_mmbatch_flush_function
*
)
extra
=
libxsmm_mmbatch_end
;
}
# if !defined(NDEBUG)
else
{
result
=
EXIT_FAILURE
;
}
# endif
}
}
}
#endif
}
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
else
{
libxsmm_mmbatch_item
*
const
batcharray
=
(
libxsmm_mmbatch_item
*
)
libxsmm_mmbatch_array
;
const
unsigned
int
max_batchsize
=
(
unsigned
int
)((
LIBXSMM_GEMM_MMBATCH_SCALE
)
*
libxsmm_mmbatch_size
);
i
=
((
LIBXSMM_ATOMIC_ADD_FETCH
(
&
internal_ext_gemm_batchsize
,
1
,
LIBXSMM_ATOMIC_RELAXED
)
-
1
)
%
max_batchsize
)
+
1
;
batcharray
[
i
-
1
].
value
.
a
=
a
;
batcharray
[
i
-
1
].
value
.
b
=
b
;
batcharray
[
i
-
1
].
value
.
c
=
c
;
LIBXSMM_ASSERT
(
0
<=
flags
);
}
if
(
libxsmm_mmbatch_size
==
(
i
-
1
))
{
/* condition ensure to flush once (first discovery) */
# if !defined(NDEBUG)
result
=
# endif
internal_mmbatch_flush
(
&
libxsmm_mmbatch_desc
,
libxsmm_mmbatch_size
,
(
libxsmm_mmbatch_item
*
)
libxsmm_mmbatch_array
);
}
# if !defined(NDEBUG)
/* library code is expected to be mute */
if
(
EXIT_SUCCESS
!=
result
&&
0
!=
libxsmm_verbosity
&&
1
==
LIBXSMM_ATOMIC_ADD_FETCH
(
&
error_once
,
1
,
LIBXSMM_ATOMIC_RELAXED
))
{
fprintf
(
stderr
,
"LIBXSMM ERROR: SGEMM batch recording failed!
\n
"
);
}
# endif
#endif
}
}
/* BLAS-symbol interceptor for DGEMV (double precision): small matrix-vector
 * products are treated as an m-by-1-by-n GEMM and served by a JIT kernel;
 * anything else is forwarded to the original BLAS symbol. */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_dgemv)(
  const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* x, const libxsmm_blasint* incx,
  const double* beta, double* y, const libxsmm_blasint* incy)
{
  LIBXSMM_ASSERT(NULL != trans && NULL != m && NULL != n && NULL != lda
    && NULL != incx && NULL != incy && NULL != alpha && NULL != beta);
  LIBXSMM_INIT
  /* eligible only when wrapping is not disabled (mode 2), both vectors have
   * unit stride, and the m-by-1-by-n problem qualifies as a small GEMM */
  if ((2 < libxsmm_gemm_wrap || 2 > libxsmm_gemm_wrap) && 1 == *incx && 1 == *incy
    && LIBXSMM_SMM(*m, 1, *n, 2/*RFO*/, sizeof(double)))
  {
    if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
      const int gemv_flags = LIBXSMM_GEMM_FLAGS(*trans, 'N');
      const libxsmm_dmmfunction kernel = libxsmm_dmmdispatch(*m, 1, *n,
        lda, n/*ldb*/, m/*ldc*/, alpha, beta, &gemv_flags, NULL);
      if (NULL != kernel) { /* JIT kernel available: y = alpha * op(A) * x + beta * y */
        LIBXSMM_MMCALL_LDX(kernel, a, x, y, *m, 1, *n, *lda, *n/*ldb*/, *m/*ldc*/);
      }
      else { /* dispatch failed: defer to the original BLAS GEMV */
        LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
      }
    }
    else { /* TODO: parallelized */
      LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
    }
  }
  else { /* not a small/unit-stride case: defer to the original BLAS GEMV */
    LIBXSMM_GEMV_SYMBOL(double)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
}
/* BLAS-symbol interceptor for SGEMV (single precision): small matrix-vector
 * products are treated as an m-by-1-by-n GEMM and served by a JIT kernel;
 * anything else is forwarded to the original BLAS symbol. */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void LIBXSMM_FSYMBOL(__wrap_sgemv)(
  const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* x, const libxsmm_blasint* incx,
  const float* beta, float* y, const libxsmm_blasint* incy)
{
  LIBXSMM_ASSERT(NULL != trans && NULL != m && NULL != n && NULL != lda
    && NULL != incx && NULL != incy && NULL != alpha && NULL != beta);
  LIBXSMM_INIT
  /* eligible only when wrapping is not disabled (mode 2), both vectors have
   * unit stride, and the m-by-1-by-n problem qualifies as a small GEMM */
  if ((2 < libxsmm_gemm_wrap || 2 > libxsmm_gemm_wrap) && 1 == *incx && 1 == *incy
    && LIBXSMM_SMM(*m, 1, *n, 2/*RFO*/, sizeof(float)))
  {
    if (0 != (libxsmm_gemm_wrap & 1)) { /* sequential */
      const int gemv_flags = LIBXSMM_GEMM_FLAGS(*trans, 'N');
      const libxsmm_smmfunction kernel = libxsmm_smmdispatch(*m, 1, *n,
        lda, n/*ldb*/, m/*ldc*/, alpha, beta, &gemv_flags, NULL);
      if (NULL != kernel) { /* JIT kernel available: y = alpha * op(A) * x + beta * y */
        LIBXSMM_MMCALL_LDX(kernel, a, x, y, *m, 1, *n, *lda, *n/*ldb*/, *m/*ldc*/);
      }
      else { /* dispatch failed: defer to the original BLAS GEMV */
        LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
      }
    }
    else { /* TODO: parallelized */
      LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
    }
  }
  else { /* not a small/unit-stride case: defer to the original BLAS GEMV */
    LIBXSMM_GEMV_SYMBOL(float)(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
  }
}
/* C-binding interceptor for dgemm_batch: forwards verbatim to the
 * Fortran-decorated wrapper symbol (same name via LIBXSMM_FSYMBOL). */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void __wrap_dgemm_batch(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_FSYMBOL(__wrap_dgemm_batch)(
    transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array,
    beta_array, c_array, ldc_array, group_count, group_size);
}
/* C-binding interceptor for sgemm_batch: forwards verbatim to the
 * Fortran-decorated wrapper symbol (same name via LIBXSMM_FSYMBOL). */
LIBXSMM_APIEXT LIBXSMM_ATTRIBUTE_USED void __wrap_sgemm_batch(
  const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  LIBXSMM_FSYMBOL(__wrap_sgemm_batch)(
    transa_array, transb_array, m_array, n_array, k_array,
    alpha_array, a_array, lda_array, b_array, ldb_array,
    beta_array, c_array, ldc_array, group_count, group_size);
}
#endif
/*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT)*/
/* OpenMP-parallelized GEMM for arbitrary input/output precision (iprec/oprec).
 * Initializes a tiling handle for the problem, allocates scratch if the handle
 * requires it, and runs the tiled multiplication either with an internal
 * OpenMP team (parallel region or tasks) or inside the caller's existing
 * parallel region. On handle/scratch failure the call falls back to BLAS. */
LIBXSMM_APIEXT void libxsmm_xgemm_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda,
  const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc)
{
  libxsmm_gemm_blob blob;
#if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
  /* inside an active parallel region use the current team size, otherwise the maximum */
  const int outerpar = omp_get_active_level(), nthreads = (0 == outerpar
    ? omp_get_max_threads() : omp_get_num_threads());
#elif defined(_OPENMP)
  /* without task support, nested execution runs single-threaded per caller thread */
  const int outerpar = omp_in_parallel(), nthreads = (0 == outerpar
    ? omp_get_max_threads() : 1);
#else
  const int nthreads = 1;
#endif
  const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob, iprec, oprec,
    transa, transb, m, n, k, lda, ldb, ldc, alpha, beta,
    LIBXSMM_GEMM_HANDLE_FLAG_AUTO, nthreads);
  /* NOTE(review): called before the NULL-check on handle below; presumably
   * libxsmm_gemm_handle_get_scratch_size tolerates NULL — confirm. */
  const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle);
  void* scratch = NULL;
  if (NULL != handle && (0 == scratch_size
    || NULL != (scratch = libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER))))
  {
#if defined(_OPENMP)
    if (0 == outerpar) { /* enable internal parallelization */
# if defined(LIBXSMM_EXT_TASKS)
      if (0 == libxsmm_gemm_tasks)
# endif
      { /* classic fork-join: one task per thread of a fresh team */
#       pragma omp parallel num_threads(nthreads)
        libxsmm_gemm_task(handle, scratch, a, b, c, omp_get_thread_num(), nthreads);
      }
# if defined(LIBXSMM_EXT_TASKS)
      else { /* tasks requested */
        const int ntasks = nthreads; /* TODO: apply grain-size */
#       pragma omp parallel num_threads(nthreads)
        { /* first thread discovering work will launch all tasks */
#         pragma omp single nowait /* anyone is good */
          {
            int tid;
            for (tid = 0; tid < ntasks; ++tid) {
#             pragma omp task untied
              libxsmm_gemm_task(handle, scratch, a, b, c, tid, ntasks);
            }
          }
        } /* implicit synchronization (barrier) */
      }
# endif
    }
    else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS) /* implies _OPENMP */
      const int ntasks = nthreads; /* TODO: apply grain-size */
      int tid;
      for (tid = 0; tid < ntasks; ++tid) {
#       pragma omp task untied
        libxsmm_gemm_task(handle, scratch, a, b, c, tid, ntasks);
      }
      if (0 == libxsmm_nosync) { /* allow to omit synchronization */
#       pragma omp taskwait
      }
# else
      /* no task support: this caller thread processes the whole problem */
      libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*nthreads*/);
# endif
    }
    if (LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) { /* library code is expected to be mute */
      /* report growing worker imbalance: ratio of threads vs. tile-tasks (mt*nt*kt) */
      const unsigned int ntasks = handle->mt * handle->nt * handle->kt;
      const double imbalance = 100.0 * LIBXSMM_DELTA((unsigned int)nthreads, ntasks) / nthreads;
      static double max_imbalance = 50.0;
      if (max_imbalance < imbalance) {
        fprintf(stderr, "LIBXSMM WARNING: XGEMM %.0f%% imbalance (%u of %i workers utilized)!\n",
          imbalance, ntasks, nthreads);
        max_imbalance = imbalance;
      }
    }
#else
    libxsmm_gemm_task(handle, scratch, a, b, c, 0/*tid*/, 1/*nthreads*/);
#endif /*defined(_OPENMP)*/
    libxsmm_free(scratch);
  }
  else { /* fallback or error */
    static int error_once = 0;
    if (NULL == handle) { /* fallback */
      if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */
        && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
      {
        fprintf(stderr, "LIBXSMM WARNING: XGEMM fallback code path triggered!\n");
      }
    }
    else if (0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: failed to allocate GEMM-scratch memory!\n");
    }
    /* either handle-init failed (unsupported shape) or scratch allocation failed:
     * compute the result via the BLAS path so the call still completes */
    libxsmm_blas_xgemm(iprec, oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
/* Internal driver behind the batch/omp entry points: processes "group_count"
 * groups of batched multiplications (group g has batchsize[g] problems of
 * shape m[g] x n[g] x k[g]). Up to LIBXSMM_GEMM_NPARGROUPS groups are handled
 * per chunk: a JIT kernel is dispatched per group when the shape qualifies as
 * a small matrix multiplication (SMM); the chunk then runs via OpenMP
 * parallel-for, OpenMP tasks, or sequentially. If any group of the chunk is
 * not suitable (or dispatch fails), the whole chunk falls back to BLAS via
 * libxsmm_mmbatch_blas. The sign of group_count/batchsize is forwarded to the
 * kernel driver (semantics defined by libxsmm_mmbatch_kernel, not visible here). */
LIBXSMM_API_INLINE void internal_gemm_batch_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char transa[], const char transb[],
  const libxsmm_blasint m[], const libxsmm_blasint n[], const libxsmm_blasint k[],
  const void* alpha, const void* a[], const libxsmm_blasint lda[],
  const void* b[], const libxsmm_blasint ldb[],
  const void* beta, void* c[], const libxsmm_blasint ldc[],
  libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint batchsize[], libxsmm_blasint group_count)
{
  static int error_once = 0;
  LIBXSMM_INIT
  if ( /* check for sensible arguments */
#if defined(LIBXSMM_BATCH_CHECK)
    NULL != a && NULL != b && NULL != c && (1 == group_count || -1 == group_count || (0 == index_stride
      && (NULL == stride_a || 0 != *stride_a)
      && (NULL == stride_b || 0 != *stride_b)
      && (NULL == stride_c || 0 != *stride_c))) &&
#endif
    0 != group_count)
  {
    int result = EXIT_SUCCESS;
    /* number of groups processed in parallel per chunk (bounded by compile-time cap) */
    const int max_npargroups = (int)(0 < libxsmm_gemm_npargroups
      ? LIBXSMM_MIN(libxsmm_gemm_npargroups, LIBXSMM_GEMM_NPARGROUPS)
      : LIBXSMM_GEMM_NPARGROUPS);
    const libxsmm_gemm_prefetch_type prefetch = libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO);
    /* byte-strides between consecutive operands; default is pointer-array layout */
    const size_t sa = (NULL != stride_a ? (size_t)(*stride_a) : sizeof(void*));
    const size_t sb = (NULL != stride_b ? (size_t)(*stride_b) : sizeof(void*));
    const size_t sc = (NULL != stride_c ? (size_t)(*stride_c) : sizeof(void*));
    const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)oprec);
    const int ngroups = (int)LIBXSMM_ABS(group_count);
    int group = 0, group_next = LIBXSMM_GEMM_NPARGROUPS;
    libxsmm_code_pointer kernel[LIBXSMM_GEMM_NPARGROUPS]; /* one dispatched kernel per group of the chunk */
    libxsmm_blasint base[LIBXSMM_GEMM_NPARGROUPS], i;     /* running operand offset (in elements of s[abc]) per slot */
#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
    int kflags[LIBXSMM_GEMM_NPARGROUPS]; /* descriptor flags captured at dispatch time */
#endif
    int max_nthreads = 1;
#if defined(_OPENMP)
# if defined(LIBXSMM_EXT_TASKS)
    const int outerpar = omp_get_active_level();
# else
    const int outerpar = omp_in_parallel();
# endif
    if (0 == outerpar) max_nthreads = omp_get_max_threads();
#endif
    for (i = 0; i < max_npargroups; ++i) {
#if !defined(NDEBUG)
      kernel[i].ptr = NULL;
# if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
      kflags[i] = 0;
# endif
#endif
      base[i] = 0;
    }
    for (group = 0; group < ngroups; group = group_next, group_next += max_npargroups) {
      /* NOTE(review): npargroups is min(group_next, ngroups), i.e. an absolute
       * index bound rather than a per-chunk count; for chunks beyond the first
       * the loops below iterate i in [0, npargroups) with g = group + i —
       * verify intended behavior when ngroups > LIBXSMM_GEMM_NPARGROUPS. */
      const int npargroups = LIBXSMM_MIN(group_next, ngroups);
      libxsmm_blasint size = 0;
      int suitable = 0;
      if (0 < group) { /* base is maintained even if par-group is not suitable */
        for (i = 0; i < npargroups; ++i) {
          const libxsmm_blasint isize = batchsize[group+i-1], asize = LIBXSMM_ABS(isize);
          base[i] += asize;
        }
      }
      /* dispatch one kernel per group of this chunk; bail out on first unsuitable group */
      for (i = 0; i < npargroups; ++i) {
        const libxsmm_blasint g = group + i, im = m[g], in = n[g], ik = k[g];
        suitable = LIBXSMM_SMM_AI(im, in, ik, 2/*RFO*/, otypesize);
        if (0 != suitable) {
          const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
          const char *const ta = (NULL != transa ? (transa + g) : NULL);
          const char *const tb = (NULL != transb ? (transb + g) : NULL);
          const int flags = LIBXSMM_GEMM_PFLAGS(ta, tb, LIBXSMM_FLAGS);
          /* alpha/beta are reinterpreted as per-group arrays via galpha[g]/gbeta[g] */
          const void **const galpha = &alpha, **const gbeta = &beta;
          libxsmm_descriptor_blob blob;
          /* coverity[ptr_arith] */
          libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, im, in, ik,
            NULL != lda ? lda[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & flags) ? im : ik),
            NULL != ldb ? ldb[g] : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & flags) ? ik : in),
            NULL != ldc ? ldc[g] : im,
            NULL != alpha ? galpha[g] : NULL,
            NULL != beta ? gbeta[g] : NULL,
            flags, prefetch);
          if (NULL != desc) {
            libxsmm_gemm_internal_set_batchflag(desc, c, index_stride,
              0 < group_count ? isize : -asize, 1 != max_nthreads);
            kernel[i].xgemm = libxsmm_xmmdispatch(desc);
          }
          else kernel[i].ptr = NULL;
          if (NULL != kernel[i].ptr_const) {
            if (size < asize) size = asize; /* largest batch size of the chunk */
#if !defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
            LIBXSMM_ASSERT(NULL != desc);
            /* coverity[var_deref_op] */
            kflags[i] = desc->flags;
#endif
          }
          else { /* dispatch failed: fall back for the whole chunk */
            suitable = 0;
            break;
          }
        }
        else break;
      }
      if (0 != suitable) { /* check if an SMM is suitable */
        const unsigned char itypesize = libxsmm_typesize((libxsmm_datatype)iprec);
#if defined(_OPENMP)
        const int nchunks = (int)LIBXSMM_UPDIV(size, libxsmm_gemm_taskgrain);
        const int ntasks = nchunks * npargroups, nthreads = LIBXSMM_MIN(max_nthreads, ntasks);
        if (1 < nthreads) {
          if (0 == outerpar) { /* enable internal parallelization */
# if defined(LIBXSMM_EXT_TASKS)
            if (0 == libxsmm_gemm_tasks)
# endif
            {
#             pragma omp parallel for num_threads(nthreads) private(i)
              for (i = 0; i < ntasks; ++i) {
                /* u: group slot within the chunk, v: offset within that group's batch */
                const libxsmm_blasint j = i * libxsmm_gemm_taskgrain,
                  u = j / size, v = j - u * size, g = group + u;
                const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
                if (v < asize) {
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                  libxsmm_mmkernel_info kernel_info;
#endif
                  /* NOTE(review): kernel/kflags are populated at index i (0..npargroups-1)
                   * but read here at g = group + u; for group > 0 this exceeds the arrays'
                   * LIBXSMM_GEMM_NPARGROUPS extent — looks like it may be intended as [u].
                   * TODO confirm against upstream. */
                  /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride,
                    stride_a, stride_b, stride_c,
                    (const char*)a + sa * base[u],
                    (const char*)b + sb * base[u],
                    (char*)c + sc * base[u],
                    0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                    EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info)
                      ? kernel_info.flags : 0);
#else
                    kflags[g]);
#endif
                }
              }
            }
# if defined(LIBXSMM_EXT_TASKS)
            else { /* tasks requested */
#             pragma omp parallel num_threads(nthreads) private(i)
              { /* first thread discovering work will launch all tasks */
#               pragma omp single nowait /* anyone is good */
                for (i = 0; i < ntasks; ++i) {
                  const libxsmm_blasint j = i * libxsmm_gemm_taskgrain,
                    u = j / size, v = j - u * size, g = group + u;
                  const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
                  if (v < asize) {
#                   pragma omp task
                    {
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                      libxsmm_mmkernel_info kernel_info;
#endif
                      /* NOTE(review): same kernel[g]/kflags[g] indexing concern as above */
                      /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride,
                        stride_a, stride_b, stride_c,
                        (const char*)a + sa * base[u],
                        (const char*)b + sb * base[u],
                        (char*)c + sc * base[u],
                        0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                        EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info)
                          ? kernel_info.flags : 0);
#else
                        kflags[g]);
#endif
                    }
                  }
                }
              } /* implicit synchronization (barrier) */
            }
# endif
          }
          else { /* assume external parallelization */
            for (i = 0; i < (libxsmm_blasint)ntasks; ++i) {
              const libxsmm_blasint j = i * libxsmm_gemm_taskgrain,
                u = j / size, v = j - u * size, g = group + u;
              const libxsmm_blasint isize = batchsize[g], asize = LIBXSMM_ABS(isize);
              if (v < asize) {
# if defined(LIBXSMM_EXT_TASKS) /* OpenMP-tasks */
#               pragma omp task
# endif
                {
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                  libxsmm_mmkernel_info kernel_info;
#endif
                  /* NOTE(review): same kernel[g]/kflags[g] indexing concern as above */
                  /*check*/libxsmm_mmbatch_kernel(kernel[g].xgemm, index_base, index_stride,
                    stride_a, stride_b, stride_c,
                    (const char*)a + sa * base[u],
                    (const char*)b + sb * base[u],
                    (char*)c + sc * base[u],
                    0 < group_count ? isize : -asize, (int)i, nchunks, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
                    EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[g].xgemm, &kernel_info)
                      ? kernel_info.flags : 0);
#else
                    kflags[g]);
#endif
                }
              }
            }
# if defined(LIBXSMM_EXT_TASKS) /* OpenMP-tasks */
            if (0 == libxsmm_nosync) { /* allow to omit synchronization */
#             pragma omp taskwait
            }
# endif
          }
        }
        else
#endif /*defined(_OPENMP)*/
        { /* sequential */
          for (i = 0; i < npargroups; ++i) {
            const libxsmm_blasint g = group + i;
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
            libxsmm_mmkernel_info kernel_info;
#endif
            libxsmm_mmbatch_kernel(kernel[i].xgemm, index_base, index_stride,
              stride_a, stride_b, stride_c,
              (const char*)a + sa * base[i],
              (const char*)b + sb * base[i],
              (char*)c + sc * base[i],
              batchsize[g], 0/*tid*/, 1/*nthreads*/, itypesize, otypesize,
#if defined(LIBXSMM_EXT_GEMM_PARGROUPS_INFO)
              EXIT_SUCCESS == libxsmm_get_mmkernel_info(kernel[i].xgemm, &kernel_info)
                ? kernel_info.flags : 0);
#else
              kflags[i]);
#endif
          }
        }
      }
      else { /* trigger fallback */
        result = EXIT_FAILURE;
      }
      if (EXIT_SUCCESS != result) { /* BLAS fallback for every group of this chunk */
        for (i = 0; i < npargroups; ++i) {
          const libxsmm_blasint g = group + i;
          const char *const ta = (NULL != transa ? (transa + g) : NULL);
          const char *const tb = (NULL != transb ? (transb + g) : NULL);
          const int flags = LIBXSMM_GEMM_PFLAGS(ta, tb, LIBXSMM_FLAGS);
          const libxsmm_blasint im = m[g], in = n[g], ik = k[g];
          const libxsmm_blasint ilda = (NULL != lda ? lda[g]
            : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & flags) ? im : ik));
          const libxsmm_blasint ildb = (NULL != ldb ? ldb[g]
            : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & flags) ? ik : in));
          const libxsmm_blasint ildc = (NULL != ldc ? ldc[g] : im);
          const void **const galpha = &alpha, **const gbeta = &beta;
          /* coverity[overrun-local] */
          const void *const ialpha = (NULL != alpha ? galpha[g] : NULL);
          /* coverity[overrun-local] */
          const void *const ibeta = (NULL != beta ? gbeta[g] : NULL);
          if (EXIT_SUCCESS == libxsmm_mmbatch_blas(iprec, oprec, ta, tb, im, in, ik, ialpha,
            (const char*)a + sa * base[i], &ilda,
            (const char*)b + sb * base[i], &ildb, ibeta,
            (char*)c + sc * base[i], &ildc,
            index_base, index_stride, stride_a, stride_b, stride_c, batchsize[g]))
          {
            if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) {
              /* warn once per growing problem size about the BLAS fallback */
              /* NOTE(review): third argument is "im" — looks like it may be
               * intended as "ik"; TODO confirm. */
              const size_t threshold = LIBXSMM_MNK_SIZE(im, in, im);
              static size_t threshold_max = 0;
              if (threshold_max < threshold) {
                LIBXSMM_STDIO_ACQUIRE();
                fprintf(stderr, "LIBXSMM WARNING: ");
                libxsmm_gemm_print2(stderr, iprec, oprec, ta, tb, &im, &in, &ik,
                  ialpha, NULL/*a*/, &ilda, NULL/*b*/, &ildb, ibeta, NULL/*c*/, &ildc);
                fprintf(stderr, " => batched GEMM/omp was falling back to BLAS!\n");
                LIBXSMM_STDIO_RELEASE();
                threshold_max = threshold;
              }
            }
          }
          else {
            if (0 != libxsmm_verbosity /* library code is expected to be mute */
              && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
            {
              fprintf(stderr, "LIBXSMM ERROR: libxsmm_gemm_batch_omp failed!\n");
            }
            return; /* exit routine */
          }
        }
      }
    }
  }
#if defined(LIBXSMM_BATCH_CHECK)
  else if (0 != group_count && 0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM ERROR: incorrect arguments (libxsmm_gemm_batch_omp)!\n");
  }
#endif
}
/* Public batch entry point for a single group: adapts the scalar m/n/k and
 * batchsize arguments to the array-based internal driver (group_count = 1). */
LIBXSMM_APIEXT void libxsmm_gemm_batch_omp(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
  const void* alpha, const void* a, const libxsmm_blasint* lda,
  const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc,
  libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  libxsmm_blasint batchsize)
{
  internal_gemm_batch_omp(iprec, oprec, transa, transb, &m, &n, &k, alpha,
    (const void**)a, lda, (const void**)b, ldb, beta, (void**)c, ldc,
    index_base, index_stride, stride_a, stride_b, stride_c, &batchsize, 1);
}
/* MKL-style grouped batch interface (double precision): forwards the group
 * arrays to the internal driver using pointer-array operand layout
 * (stride = sizeof(void*), zero index base/stride). No-op if group_count is NULL. */
LIBXSMM_APIEXT void libxsmm_dgemm_batch_omp(const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[],
  const double* b_array[], const libxsmm_blasint ldb_array[],
  const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  if (NULL != group_count) {
    const libxsmm_blasint ptrsize = sizeof(void*);
    internal_gemm_batch_omp(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64,
      transa_array, transb_array, m_array, n_array, k_array, alpha_array,
      (const void**)a_array, lda_array, (const void**)b_array, ldb_array,
      beta_array, (void**)c_array, ldc_array,
      0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize,
      group_size, *group_count);
  }
}
/* MKL-style grouped batch interface (single precision): forwards the group
 * arrays to the internal driver using pointer-array operand layout
 * (stride = sizeof(void*), zero index base/stride). No-op if group_count is NULL. */
LIBXSMM_APIEXT void libxsmm_sgemm_batch_omp(const char transa_array[], const char transb_array[],
  const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
  const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[],
  const float* b_array[], const libxsmm_blasint ldb_array[],
  const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[],
  const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
{
  if (NULL != group_count) {
    const libxsmm_blasint ptrsize = sizeof(void*);
    internal_gemm_batch_omp(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32,
      transa_array, transb_array, m_array, n_array, k_array, alpha_array,
      (const void**)a_array, lda_array, (const void**)b_array, ldb_array,
      beta_array, (void**)c_array, ldc_array,
      0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize,
      group_size, *group_count);
  }
}
/* Starts a batch-recording region for GEMMs matching the given descriptor
 * (precision/shape/leading dimensions/alpha/beta/flags). Any batch already in
 * progress is flushed, its descriptor is pushed on a small ring of saved
 * descriptors (restored by libxsmm_mmbatch_end), and recording is re-armed.
 * All arguments must be non-NULL and the recording array must exist; the
 * global batch lock is acquired with a try-lock, so the call silently does
 * nothing if the lock is contended. Without LIBXSMM_WRAP/LIBXSMM_BUILD_EXT
 * this function is a no-op. */
LIBXSMM_APIEXT void libxsmm_mmbatch_begin(libxsmm_gemm_precision precision, const int* flags,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc,
  const void* alpha, const void* beta)
{
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# if defined(_MSC_VER)
#   pragma warning(push)
#   pragma warning(disable: 26115) /* try-lock is treated incorrectly by static analysis */
# endif
  LIBXSMM_INIT
  if (NULL != libxsmm_mmbatch_array /* batch-recording available, but not yet running */
    /* currently, batch recording is only enabled if all values are present (no complex filtering) */
    && NULL != flags && NULL != alpha && NULL != beta
    && NULL != lda && NULL != ldb && NULL != ldc
    && NULL != m && NULL != n && NULL != k
    && LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_DEFAULT) == LIBXSMM_LOCK_TRYLOCK(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock))
  {
    libxsmm_descriptor_blob blob;
    const libxsmm_gemm_descriptor *const descriptor = libxsmm_gemm_descriptor_init(&blob,
      precision, *m, *n, *k, *lda, *ldb, *ldc, alpha, beta, *flags,
      libxsmm_get_gemm_prefetch(LIBXSMM_EXT_GEMM_MMBATCH_PREFETCH));
    static int error_once = 0;
    int result = EXIT_SUCCESS;
    if (NULL != descriptor) {
      const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
      unsigned int i;
#if !defined(NDEBUG)
      const unsigned int mmbatch_maxdepth = LIBXSMM_UP2POT(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH);
      LIBXSMM_ASSERT((LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) == mmbatch_maxdepth/*is pot*/);
#endif
      /* eventually overwrite the oldest entry */
      i = LIBXSMM_MOD2(internal_ext_gemm_batchdepth, LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH);
      internal_ext_gemm_batchdesc[i] = libxsmm_mmbatch_desc; /* backup */
      ++internal_ext_gemm_batchdepth;
      /* ensure descriptor does not match any GEMM such that... */
      LIBXSMM_MEMZERO127(&libxsmm_mmbatch_desc);
      /* ...the batch stops and completely flushes */
      if (0 != internal_ext_gemm_batchsize) {
        /* flush whatever was recorded so far under the saved descriptor;
         * the modulo folds the monotonically growing counter into the ring */
        result = internal_mmbatch_flush(internal_ext_gemm_batchdesc + i,
          (((libxsmm_blasint)internal_ext_gemm_batchsize - 1) % max_batchsize) + 1,
          (libxsmm_mmbatch_item*)libxsmm_mmbatch_array);
      }
      if (EXIT_SUCCESS == result) { /* enable descriptor */
        internal_ext_gemm_batchsize = 0; /* reset */
        if (0 == (LIBXSMM_MMBATCH_FLAG_STATISTIC & *flags)) {
          libxsmm_mmbatch_desc = *descriptor;
        }
        else { /* statistic-only mode: record shapes, not call arguments */
          libxsmm_mmbatch_desc.flags = LIBXSMM_MMBATCH_FLAG_STATISTIC;
        }
      }
    }
    else {
      result = EXIT_FAILURE;
    }
    if (EXIT_SUCCESS != result && 0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: GEMM batch enabling failed!\n");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock);
  }
# if defined(_MSC_VER)
#   pragma warning(pop)
# endif
#else
  LIBXSMM_UNUSED(precision); LIBXSMM_UNUSED(flags);
  LIBXSMM_UNUSED(m); LIBXSMM_UNUSED(n); LIBXSMM_UNUSED(k);
  LIBXSMM_UNUSED(lda); LIBXSMM_UNUSED(ldb); LIBXSMM_UNUSED(ldc);
  LIBXSMM_UNUSED(alpha); LIBXSMM_UNUSED(beta);
#endif
}
/* Ends the current batch-recording region: flushes all recorded multiplications
 * under the active descriptor and restores the previously saved descriptor from
 * the ring maintained by libxsmm_mmbatch_begin. Uses a try-lock on the global
 * batch lock; if the lock is contended the call silently does nothing. Without
 * LIBXSMM_WRAP/LIBXSMM_BUILD_EXT this function is a no-op. */
LIBXSMM_APIEXT void libxsmm_mmbatch_end(void)
{
#if defined(LIBXSMM_WRAP) && defined(LIBXSMM_BUILD_EXT)
# if defined(_MSC_VER)
#   pragma warning(push)
#   pragma warning(disable: 26115) /* try-lock is treated incorrectly by static analysis */
# endif
  /*const*/ int trystate = LIBXSMM_LOCK_TRYLOCK(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock);
  if (LIBXSMM_LOCK_ACQUIRED(LIBXSMM_LOCK_DEFAULT) == trystate) {
    const unsigned int max_batchsize = (unsigned int)((LIBXSMM_GEMM_MMBATCH_SCALE) * libxsmm_mmbatch_size);
    const libxsmm_gemm_descriptor flushdesc = libxsmm_mmbatch_desc; /* local copy used for the flush */
    static int error_once = 0;
#if !defined(NDEBUG)
    const unsigned int mmbatch_maxdepth = LIBXSMM_UP2POT(LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH);
#endif
    /* ensure descriptor does not match any GEMM such that... */
    LIBXSMM_MEMZERO127(&libxsmm_mmbatch_desc);
    /* ...the batch stops and completely flushes */
    if (EXIT_SUCCESS == internal_mmbatch_flush(&flushdesc,
      0 != internal_ext_gemm_batchsize
        ? ((((internal_ext_gemm_batchsize - 1) % max_batchsize)) + 1)
        : 0,
      (libxsmm_mmbatch_item*)libxsmm_mmbatch_array))
    {
      internal_ext_gemm_batchsize = 0; /* reset */
      --internal_ext_gemm_batchdepth; /* restore the previous descriptor */
      assert((LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH) == mmbatch_maxdepth/*is pot*/); /* no LIBXSMM_ASSERT! */
      libxsmm_mmbatch_desc = internal_ext_gemm_batchdesc[LIBXSMM_MOD2(
        internal_ext_gemm_batchdepth, LIBXSMM_EXT_GEMM_MMBATCH_MAXDEPTH)];
    }
    else if (0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM ERROR: GEMM batch processing failed!\n");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_DEFAULT, &libxsmm_mmbatch_lock);
  }
# if defined(_MSC_VER)
#   pragma warning(pop)
# endif
#endif
}
#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_xgemm_omp)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*, const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
/** Fortran binding: dereferences the by-reference precision arguments and
 *  forwards all remaining BLAS-style pointer arguments unchanged. */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_xgemm_omp)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  libxsmm_gemm_precision input_precision, output_precision;
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec);
  input_precision = *iprec;
  output_precision = *oprec;
  libxsmm_xgemm_omp(input_precision, output_precision,
    transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_dgemm_omp)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*,
  const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
/** Fortran binding of the double-precision parallel GEMM: all arguments are
 *  already BLAS-style by-reference, so they pass straight through. */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_dgemm_omp)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  libxsmm_dgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_sgemm_omp)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
/** Fortran binding of the single-precision parallel GEMM: all arguments are
 *  already BLAS-style by-reference, so they pass straight through. */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_sgemm_omp)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  libxsmm_sgemm_omp(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_gemm_batch_omp)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*,
  const void*, void*, const libxsmm_blasint*,
  const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[],
  const libxsmm_blasint*);
/** Fortran binding of the parallel batched GEMM: validates the mandatory
 *  by-reference scalars, dereferences them, and forwards to the C entry point. */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_gemm_batch_omp)(
  const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda,
  const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc,
  const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint* batchsize)
{
  libxsmm_gemm_precision input_precision, output_precision;
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize);
  input_precision = *iprec;
  output_precision = *oprec;
  libxsmm_gemm_batch_omp(input_precision, output_precision, transa, transb,
    *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc,
    *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_begin)(const libxsmm_gemm_precision*,
  const int*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*);
/** Fortran binding: dereferences the mandatory precision argument and forwards
 *  the remaining (possibly optional/NULL) by-reference arguments unchanged. */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_begin)(const libxsmm_gemm_precision* precision,
  const int* flags, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc,
  const void* alpha, const void* beta)
{
  libxsmm_gemm_precision precision_value;
  LIBXSMM_ASSERT(NULL != precision);
  precision_value = *precision;
  libxsmm_mmbatch_begin(precision_value, flags, m, n, k, lda, ldb, ldc, alpha, beta);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_end)(void);
/** Fortran binding: zero-argument forwarder to libxsmm_mmbatch_end. */
LIBXSMM_APIEXT void LIBXSMM_FSYMBOL(libxsmm_mmbatch_end)(void)
{
  libxsmm_mmbatch_end();
}
#endif
/*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/
third_party/libxsmm/src/libxsmm_ext_xcopy.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_xcopy.h"
#include "libxsmm_ext.h"
/* Multithreading gate for the matrix-copy: true when at least one full tile
 * (MT x NT) fits into the M x N problem and the problem covers 64x64 elements. */
#define LIBXSMM_MCOPY_MT(MT, NT, M, N) ((MT) <= (M) && (NT) <= (N) && (64U * 64U) <= (((unsigned int)(M)) * (N)))
/** OpenMP-parallelized matrix copy (out <- in), or matrix zeroing when in == NULL.
 *  typesize is the element size in bytes (1..255); ldi/ldo are leading dimensions
 *  (>= m). Invalid arguments produce a one-time stderr diagnostic (if verbose)
 *  instead of touching memory. */
LIBXSMM_APIEXT void libxsmm_matcopy_omp(void* out, const void* in, unsigned int typesize,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo)
{
  LIBXSMM_INIT
  /* argument validation: sane typesize, leading dimensions, non-aliased buffers,
     and either a real problem (m,n > 0) or the degenerate empty case (m == n == 0) */
  if (0 < typesize && 256 > typesize && m <= ldi && m <= ldo && out != in
    && ((NULL != out && 0 < m && 0 < n) || (0 == m && 0 == n)))
  {
    if (0 < m && 0 < n) {
#if defined(_OPENMP)
      /* tile extents (tm x tn) and byte budget (ts) steering the work partition */
      unsigned int tm, tn, ts;
      if (NULL != in) { /* mcopy */
        tm = LIBXSMM_UPDIV(libxsmm_mcopy_mbytes, typesize);
        tn = (unsigned int)(libxsmm_mcopy_nscale * tm);
        ts = libxsmm_mcopy_mbytes;
      }
      else { /* mzero */
        tm = LIBXSMM_UPDIV(libxsmm_mzero_mbytes, typesize);
        tn = (unsigned int)(libxsmm_mzero_nscale * tm);
        ts = libxsmm_mzero_mbytes;
      }
      /* fall back to whole-problem extents when the tunables are unset */
      if (0 == tm) tm = m;
      if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n);
      /* shrink tm so a tile stays within the ts byte budget */
      if (0 != ts && ts < (tm * tn * typesize)) {
        tm = LIBXSMM_MAX(ts / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN);
      }
      if (LIBXSMM_MCOPY_MT(tm, tn, (unsigned int)m, (unsigned int)n)) { /* consider problem-size */
        libxsmm_xcopykernel kernel;
        kernel.ptr = NULL;
# if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 2))
        if (0 != (2 & libxsmm_xcopy_jit)) { /* JIT'ted matrix-copy permitted? */
          /* dispatch a per-tile elementwise kernel; IDENTITY copies, XOR zeroes */
          switch (typesize) {
            case 8: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
              LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
              LIBXSMM_MELTW_FLAG_UNARY_NONE,
              NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
            case 4: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
              LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
              LIBXSMM_MELTW_FLAG_UNARY_NONE,
              NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
            case 2: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
              LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16,
              LIBXSMM_MELTW_FLAG_UNARY_NONE,
              NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
            case 1: kernel.function = libxsmm_dispatch_meltw_unary(tm, tn, &ldi, &ldo,
              LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8,
              LIBXSMM_MELTW_FLAG_UNARY_NONE,
              NULL != in ? LIBXSMM_MELTW_TYPE_UNARY_IDENTITY/*mcopy*/ : LIBXSMM_MELTW_TYPE_UNARY_XOR/*mzero*/);
              break;
          }
        }
# endif
# if defined(LIBXSMM_EXT_TASKS) && 0/* implies _OPENMP */
        if (0 == omp_get_active_level())
# else
        if (0 == omp_in_parallel())
# endif
        { /* enable internal parallelization */
          const int nthreads = omp_get_max_threads();
# if defined(LIBXSMM_EXT_TASKS)
          if (0 >= libxsmm_xcopy_taskscale)
# endif
          {
#           pragma omp parallel num_threads(nthreads)
            libxsmm_matcopy_task_internal(out, in, typesize,
              (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
              tm, tn, kernel, omp_get_thread_num(), nthreads);
          }
# if defined(LIBXSMM_EXT_TASKS)
          else { /* tasks requested */
            const int ntasks = nthreads * libxsmm_xcopy_taskscale;
#           pragma omp parallel num_threads(nthreads)
            { /* first thread discovering work will launch all tasks */
#             pragma omp single nowait /* anyone is good */
              {
                int tid;
                for (tid = 0; tid < ntasks; ++tid) {
#                 pragma omp task untied
                  libxsmm_matcopy_task_internal(out, in, typesize,
                    (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                    tm, tn, kernel, tid, ntasks);
                }
              }
            }
          }
# endif
        }
        else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS)/* implies _OPENMP */
          const int nthreads = omp_get_num_threads();
          const int ntasks = (0 == libxsmm_xcopy_taskscale
            ? (LIBXSMM_XCOPY_TASKSCALE) : libxsmm_xcopy_taskscale) * nthreads;
          int tid;
          for (tid = 0; tid < ntasks; ++tid) {
#           pragma omp task untied
            libxsmm_matcopy_task_internal(out, in, typesize,
              (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
              tm, tn, kernel, tid, ntasks);
          }
          if (0 == libxsmm_nosync) { /* allow to omit synchronization */
#           pragma omp taskwait
          }
# else
          libxsmm_matcopy_task_internal(out, in, typesize,
            (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
            tm, tn, kernel, 0/*tid*/, 1/*nthreads*/);
# endif
        }
      }
      else
#endif /*defined(_OPENMP)*/
      if (NULL != in) { /* no MT, or small problem-size */
        LIBXSMM_XCOPY_NONJIT(LIBXSMM_MCOPY_KERNEL,
          typesize, out, in, ldi, ldo, 0, m, 0, n);
      }
      else { /* no MT, or small problem-size */
        /* coverity[ptr_arith] */
        LIBXSMM_XCOPY_NONJIT(LIBXSMM_MZERO_KERNEL,
          typesize, out, in, ldi, ldo, 0, m, 0, n);
      }
    }
  }
  else {
    /* invalid arguments: emit exactly one diagnostic per process (if verbose) */
    static int error_once = 0;
    if (0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      if (NULL == out) {
        fprintf(stderr, "LIBXSMM ERROR: the matrix-copy input and/or output is NULL!\n");
      }
      else if (out == in) {
        fprintf(stderr, "LIBXSMM ERROR: output and input of the matrix-copy must be different!\n");
      }
      else if (0 == typesize || 256 <= typesize) {
        fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-copy specified!\n");
      }
      else if (ldi < m || ldo < m) {
        fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the matrix-copy is/are too small!\n");
      }
      else if (0 > m || 0 > n) {
        fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the matrix-copy is/are negative!\n");
      }
    }
  }
}
/** OpenMP-parallelized out-of-place transpose (out <- in^T). For the aliased
 *  case (out == in) it falls back to the in-place transpose when ldi == ldo,
 *  otherwise reports an error. typesize is the element size in bytes (1..255);
 *  ldi >= m and ldo >= n are required. */
LIBXSMM_APIEXT void libxsmm_otrans_omp(void* out, const void* in, unsigned int typesize,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo)
{
  static int error_once = 0;
  LIBXSMM_INIT
  /* argument validation; note ldo is checked against n (transposed output) */
  if (0 < typesize && 256 > typesize && m <= ldi && n <= ldo
    && ((NULL != out && NULL != in && 0 < m && 0 < n) || (0 == m && 0 == n)))
  {
    if (0 < m && 0 < n) {
      if (out != in) {
#if defined(_OPENMP)
        /* tile extents derived from the tcopy tunables */
        unsigned int tm = LIBXSMM_UPDIV(libxsmm_tcopy_mbytes, typesize);
        unsigned int tn = (unsigned int)(libxsmm_tcopy_nscale * tm);
        if (0 == tm) tm = m;
        if (0 == tn) tn = LIBXSMM_MIN(LIBXSMM_XCOPY_TILE_MIN, n);
        /* shrink tm so a tile stays within the byte budget */
        if (0 != libxsmm_tcopy_mbytes && libxsmm_tcopy_mbytes < (tm * tn * typesize)) {
          tm = LIBXSMM_MAX(libxsmm_tcopy_mbytes / (tn * typesize), LIBXSMM_XCOPY_TILE_MIN);
        }
        if (tm <= (unsigned int)m && tn <= (unsigned int)n) { /* consider problem-size */
          libxsmm_xcopykernel kernel;
          kernel.ptr = NULL;
# if defined(LIBXSMM_EXT_TASKS)/* implies _OPENMP */
          if (0 == omp_get_active_level())
# else
          if (0 == omp_in_parallel())
# endif
          { /* enable internal parallelization */
            const int nthreads = omp_get_max_threads();
# if defined(LIBXSMM_EXT_TASKS)
            if (0 >= libxsmm_xcopy_taskscale)
# endif
            {
#             pragma omp parallel num_threads(nthreads)
              {
                /* coverity[divide_by_zero] */
                libxsmm_otrans_task_internal(out, in, typesize,
                  (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                  tm, tn, kernel, omp_get_thread_num(), nthreads);
              }
            }
# if defined(LIBXSMM_EXT_TASKS)
            else { /* tasks requested */
              const int ntasks = nthreads * libxsmm_xcopy_taskscale;
#             pragma omp parallel num_threads(nthreads)
              { /* first thread discovering work will launch all tasks */
#               pragma omp single nowait /* anyone is good */
                {
                  int tid;
                  for (tid = 0; tid < ntasks; ++tid) {
#                   pragma omp task untied
                    libxsmm_otrans_task_internal(out, in, typesize,
                      (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                      tm, tn, kernel, tid, ntasks);
                  }
                }
              }
            }
# endif
          }
          else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS)/* implies _OPENMP */
            const int nthreads = omp_get_num_threads();
            const int ntasks = (0 == libxsmm_xcopy_taskscale
              ? (LIBXSMM_XCOPY_TASKSCALE) : libxsmm_xcopy_taskscale) * nthreads;
            int tid;
            for (tid = 0; tid < ntasks; ++tid) {
#             pragma omp task untied
              libxsmm_otrans_task_internal(out, in, typesize,
                (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
                tm, tn, kernel, tid, ntasks);
            }
            if (0 == libxsmm_nosync) { /* allow to omit synchronization */
#             pragma omp taskwait
            }
# else
            /* coverity[divide_by_zero] */
            libxsmm_otrans_task_internal(out, in, typesize,
              (unsigned int)m, (unsigned int)n, (unsigned int)ldi, (unsigned int)ldo,
              tm, tn, kernel, 0/*tid*/, 1/*nthreads*/);
# endif
          }
        }
        else
#endif /*defined(_OPENMP)*/
        { /* no MT, or small problem-size */
#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1))
          libxsmm_xcopykernel kernel;
          kernel.ptr = NULL;
          if (0 != (1 & libxsmm_xcopy_jit)) { /* JIT'ted transpose permitted? */
            /* dispatch a norm-to-normT transform kernel for the element width */
            switch (typesize) {
              case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
                LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
              case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
                LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
              case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16,
                LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
              case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
                LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8,
                LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
                break;
            }
            if (NULL != kernel.ptr) { /* JIT-kernel available */
              LIBXSMM_TCOPY_CALL(kernel, typesize, in, ldi, out, ldo);
            }
          }
          else
#endif
          {
            LIBXSMM_XCOPY_NONJIT(LIBXSMM_TCOPY_KERNEL,
              typesize, out, in, ldi, ldo, 0, m, 0, n);
          }
        }
      }
      else if (ldi == ldo) {
        /* aliased buffers with equal leading dimension: in-place transpose */
        libxsmm_itrans/*TODO: omp*/(out, typesize, m, n, ldi, ldo);
      }
      else if (0 != libxsmm_verbosity /* library code is expected to be mute */
        && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
      {
        fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n");
      }
    }
  }
  else {
    /* invalid arguments: emit exactly one diagnostic per process (if verbose) */
    if (0 != libxsmm_verbosity /* library code is expected to be mute */
      && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      if (NULL == out || NULL == in) {
        fprintf(stderr, "LIBXSMM ERROR: the transpose input and/or output is NULL!\n");
      }
      else if (out == in) {
        fprintf(stderr, "LIBXSMM ERROR: output and input of the transpose must be different!\n");
      }
      else if (0 == typesize || 256 <= typesize) {
        fprintf(stderr, "LIBXSMM ERROR: invalid type-size for matrix-transpose specified!\n");
      }
      else if (ldi < m || ldo < n) {
        fprintf(stderr, "LIBXSMM ERROR: the leading dimension(s) of the transpose is/are too small!\n");
      }
      else if (0 > m || 0 > n) {
        fprintf(stderr, "LIBXSMM ERROR: the matrix extent(s) of the transpose is/are negative!\n");
      }
    }
  }
}
/** OpenMP-parallelized batch of in-place transposes. Matrices are addressed via
 *  index_base/index_stride/stride (see libxsmm_itrans_batch). A scratch buffer
 *  (stack for small problems, heap otherwise) is used for the non-square or
 *  ldi != ldo cases; batchsize <= 1 falls through to the sequential routine. */
LIBXSMM_APIEXT void libxsmm_itrans_batch_omp(void* inout, unsigned int typesize,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint ldi, libxsmm_blasint ldo,
  libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride[], libxsmm_blasint batchsize)
{
#if defined(_OPENMP)
  if (1 < batchsize) { /* consider problem-size */
    const libxsmm_blasint scratchsize = m * n * typesize;
    /* negative batchsize encodes a variant; the magnitude is the work count */
    const libxsmm_blasint size = LIBXSMM_ABS(batchsize);
    char buffer[LIBXSMM_ITRANS_BUFFER_MAXSIZE];
    char* const mat0 = (char*)inout;
    void* scratch = NULL;
    libxsmm_xcopykernel kernel = { NULL };
    if (m != n || ldi != ldo || 127 < typesize) {
      /* general case needs a bounce buffer for one matrix */
      if (scratchsize <= LIBXSMM_ITRANS_BUFFER_MAXSIZE) {
        scratch = buffer;
      }
      else {
        static int error_once = 0;
        LIBXSMM_INIT
        if (EXIT_SUCCESS != libxsmm_xmalloc(&scratch, scratchsize, 0/*auto-align*/,
            LIBXSMM_MALLOC_FLAG_SCRATCH | LIBXSMM_MALLOC_FLAG_PRIVATE, 0/*extra*/, 0/*extra_size*/)
          && 0 != libxsmm_verbosity /* library code is expected to be mute */
          && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
        {
          fprintf(stderr, "LIBXSMM ERROR: failed to allocate buffer for in-place transpose!\n");
        }
      }
#if (defined(LIBXSMM_XCOPY_JIT) && 0 != (LIBXSMM_XCOPY_JIT & 1))
      if (0 != (1 & libxsmm_xcopy_jit) /* JIT'ted transpose permitted? */
        /* avoid outgrown transpose kernel upfront */
        && (m <= LIBXSMM_CONFIG_MAX_DIM || n <= LIBXSMM_CONFIG_MAX_DIM))
      {
        /* dispatch a norm-to-normT transform kernel for the element width */
        switch (typesize) {
          case 8: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
            LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
            LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
          case 4: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
            LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
            LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
          case 2: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
            LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16, LIBXSMM_DATATYPE_I16,
            LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
          case 1: kernel.function = libxsmm_dispatch_meltw_unary(m, n, &ldi, &ldo,
            LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8, LIBXSMM_DATATYPE_I8,
            LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
            break;
        }
      }
#endif
    }
# if defined(LIBXSMM_EXT_TASKS) && 0/* implies _OPENMP */
    if (0 == omp_get_active_level())
# else
    if (0 == omp_in_parallel())
# endif
    { /* enable internal parallelization */
      const int nthreads = omp_get_max_threads();
# if defined(LIBXSMM_EXT_TASKS)
      if (0 >= libxsmm_xcopy_taskscale)
# endif
      {
        /* static partition: each thread transposes a contiguous range of the batch */
        const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, nthreads);
#       pragma omp parallel num_threads(nthreads)
        {
          const libxsmm_blasint begin = omp_get_thread_num() * tasksize;
          const libxsmm_blasint span = begin + tasksize;
          libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo,
            index_base, index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size));
        }
      }
# if defined(LIBXSMM_EXT_TASKS)
      else { /* tasks requested */
        const int ntasks = nthreads * libxsmm_xcopy_taskscale;
        const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks);
#       pragma omp parallel num_threads(nthreads)
        { /* first thread discovering work will launch all tasks */
#         pragma omp single nowait /* anyone is good */
          {
            int tid;
            for (tid = 0; tid < ntasks; ++tid) {
              const libxsmm_blasint begin = tid * tasksize;
              const libxsmm_blasint span = begin + tasksize;
#             pragma omp task untied
              libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo,
                index_base, index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size));
            }
          }
        }
      }
# endif
    }
    else { /* assume external parallelization */
# if defined(LIBXSMM_EXT_TASKS)/* implies _OPENMP */
      const int nthreads = omp_get_num_threads();
      const int ntasks = (0 == libxsmm_xcopy_taskscale
        ? (LIBXSMM_XCOPY_TASKSCALE) : libxsmm_xcopy_taskscale) * nthreads;
      const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks);
      int tid;
      for (tid = 0; tid < ntasks; ++tid) {
        const libxsmm_blasint begin = tid * tasksize;
        const libxsmm_blasint span = begin + tasksize;
#       pragma omp task untied
        libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo,
          index_base, index_stride, stride, kernel, begin, LIBXSMM_MIN(span, size));
      }
      if (0 == libxsmm_nosync) { /* allow to omit synchronization */
#       pragma omp taskwait
      }
# else
      libxsmm_itrans_internal(mat0, scratch, typesize, m, n, ldi, ldo,
        index_base, index_stride, stride, kernel, 0, batchsize);
# endif
    }
    /* release the heap scratch buffer (stack buffer needs no cleanup) */
    if (NULL != scratch && LIBXSMM_ITRANS_BUFFER_MAXSIZE < scratchsize) {
      libxsmm_xfree(scratch, 0/*no check*/);
    }
  }
  else
#endif /*defined(_OPENMP)*/
  libxsmm_itrans_batch(inout, typesize, m, n, ldi, ldo,
    index_base, index_stride, stride, batchsize, 0/*tid*/, 1/*ntasks*/);
}
#if defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT
void
LIBXSMM_FSYMBOL
(
libxsmm_matcopy_omp
)(
void
*
/*out*/
,
const
void
*
/*in*/
,
const
int
*
/*typesize*/
,
const
libxsmm_blasint
*
/*m*/
,
const
libxsmm_blasint
*
/*n*/
,
const
libxsmm_blasint
*
/*ldi*/
,
const
libxsmm_blasint
*
/*ldo*/
);
LIBXSMM_APIEXT
void
LIBXSMM_FSYMBOL
(
libxsmm_matcopy_omp
)(
void
*
out
,
const
void
*
in
,
const
int
*
typesize
,
const
libxsmm_blasint
*
m
,
const
libxsmm_blasint
*
n
,
const
libxsmm_blasint
*
ldi
,
const
libxsmm_blasint
*
ldo
)
{
libxsmm_blasint
ldx
;
LIBXSMM_ASSERT
(
NULL
!=
typesize
&&
0
<
*
typesize
&&
NULL
!=
m
);
ldx
=
*
(
NULL
!=
ldi
?
ldi
:
m
);
libxsmm_matcopy_omp
(
out
,
in
,
(
unsigned
int
)
*
typesize
,
*
m
,
*
(
NULL
!=
n
?
n
:
m
),
ldx
,
NULL
!=
ldo
?
*
ldo
:
ldx
);
}
/* implementation provided for Fortran 77 compatibility */
LIBXSMM_APIEXT
void
LIBXSMM_FSYMBOL
(
libxsmm_otrans_omp
)(
void
*
/*out*/
,
const
void
*
/*in*/
,
const
int
*
/*typesize*/
,
const
libxsmm_blasint
*
/*m*/
,
const
libxsmm_blasint
*
/*n*/
,
const
libxsmm_blasint
*
/*ldi*/
,
const
libxsmm_blasint
*
/*ldo*/
);
LIBXSMM_APIEXT
void
LIBXSMM_FSYMBOL
(
libxsmm_otrans_omp
)(
void
*
out
,
const
void
*
in
,
const
int
*
typesize
,
const
libxsmm_blasint
*
m
,
const
libxsmm_blasint
*
n
,
const
libxsmm_blasint
*
ldi
,
const
libxsmm_blasint
*
ldo
)
{
libxsmm_blasint
ldx
;
LIBXSMM_ASSERT
(
NULL
!=
typesize
&&
0
<
*
typesize
&&
NULL
!=
m
);
ldx
=
*
(
NULL
!=
ldi
?
ldi
:
m
);
libxsmm_otrans_omp
(
out
,
in
,
(
unsigned
int
)
*
typesize
,
*
m
,
*
(
NULL
!=
n
?
n
:
m
),
ldx
,
NULL
!=
ldo
?
*
ldo
:
ldx
);
}
#endif
/*defined(LIBXSMM_BUILD) && defined(LIBXSMM_BUILD_EXT) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/
Prev
1
…
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment