Project: OpenDAS / dgl

Commit c454d419, authored May 12, 2023 by lisj
Parent: 3359c1f1
Commit message: "Delete the submodule's gitignore" (original: 删除子模块的gitignore)

Changes: 264 files in total; this excerpt shows 20 changed files with 7404 additions and 0 deletions (+7404 / -0).
Changed files in this excerpt:

third_party/libxsmm/samples/smm/.make                              +0     -0
third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd               +0     -0
third_party/libxsmm/scripts/libxsmm_config.py                      +145   -0
third_party/libxsmm/scripts/libxsmm_dispatch.py                    +116   -0
third_party/libxsmm/scripts/libxsmm_interface.py                   +195   -0
third_party/libxsmm/scripts/libxsmm_source.sh                      +68    -0
third_party/libxsmm/scripts/libxsmm_specialized.py                 +205   -0
third_party/libxsmm/scripts/libxsmm_utilities.py                   +320   -0
third_party/libxsmm/scripts/libxsmm_version.sh                     +30    -0
third_party/libxsmm/src/libxsmm_cpuid_arm.c                        +96    -0
third_party/libxsmm/src/libxsmm_cpuid_x86.c                        +336   -0
third_party/libxsmm/src/libxsmm_diff.h                             +144   -0
third_party/libxsmm/src/libxsmm_dnn.c                              +759   -0
third_party/libxsmm/src/libxsmm_dnn_convolution.c                  +2747  -0
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c         +719   -0
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h         +22    -0
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c          +544   -0
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h          +22    -0
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c    +914   -0
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h    +22    -0

(To preserve performance, only 264 of 264+ changed files are displayed on the page.)
third_party/libxsmm/samples/smm/.make — new file (mode 100644), empty.
third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd — new file (mode 100644). File added.
third_party/libxsmm/scripts/libxsmm_config.py — new file (mode 100755):

```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                      #
# Further information: https://github.com/hfp/libxsmm/                       #
# SPDX-License-Identifier: BSD-3-Clause                                      #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from string import Template
from datetime import date
import libxsmm_utilities
import fnmatch
import sys

if __name__ == "__main__":
    argc = len(sys.argv)
    if 1 < argc:
        # required argument(s)
        filename = sys.argv[1]
        # default configuration if no arguments are given
        ilp64 = offload = precision = flags = threshold = 0
        sync = jit = 1
        alpha = beta = 1
        cacheline = 64
        prefetch = -1
        wrap = 1
        malloc = 0
        mnklist = list()
        # optional argument(s)
        if 2 < argc:
            ilp64 = int(sys.argv[2])
        if 3 < argc:
            offload = int(sys.argv[3])
        if 4 < argc:
            cacheline = libxsmm_utilities.sanitize_alignment(int(sys.argv[4]))
        if 5 < argc:
            precision = int(sys.argv[5])
        if 6 < argc:
            prefetch = int(sys.argv[6])
        if 7 < argc:
            threshold = int(sys.argv[7])
        if 8 < argc:
            sync = int(sys.argv[8])
        if 9 < argc:
            jit = int(sys.argv[9])
        if 10 < argc:
            flags = int(sys.argv[10])
        if 11 < argc:
            alpha = int(sys.argv[11])
        if 12 < argc:
            beta = int(sys.argv[12])
        if 13 < argc:
            wrap = int(sys.argv[13])
        if 14 < argc:
            malloc = int(sys.argv[14])
        if 15 < argc:
            mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[15:], 0))
        version, branch, realversion = libxsmm_utilities.version_branch()
        major, minor, update, patch = libxsmm_utilities.version_numbers(version)
        if 0 == threshold:
            threshold = 64 * 64 * 64
        maxmnk = libxsmm_utilities.max_mnk(mnklist, threshold)
        maxdim = int(maxmnk ** (1.0 / 3.0) + 0.5)
        avgdim = int(0.5 * maxdim + 0.5)
        avgm = libxsmm_utilities.median(
            list(map(lambda mnk: mnk[0], mnklist)), avgdim, False
        )
        avgn = libxsmm_utilities.median(
            list(map(lambda mnk: mnk[1], mnklist)), avgdim, False
        )
        avgk = libxsmm_utilities.median(
            list(map(lambda mnk: mnk[2], mnklist)), avgdim, False
        )
        maxm = libxsmm_utilities.max_mnk(mnklist, avgdim, 0)
        maxn = libxsmm_utilities.max_mnk(mnklist, avgdim, 1)
        maxk = libxsmm_utilities.max_mnk(mnklist, avgdim, 2)
        substitute = {
            "VERSION": realversion,
            "BRANCH": branch,
            "MAJOR": major,
            "MINOR": minor,
            "UPDATE": update,
            "PATCH": patch,
            "DATE": date.today().strftime("%Y%m%d"),
            "CACHELINE": cacheline,
            "PREFETCH": [-1, prefetch][0 <= prefetch],
            "MAX_MNK": maxmnk,
            "MAX_DIM": maxdim,
            "AVG_DIM": int((maxdim + 1) / 2),
            "MAX_M": [maxdim, maxm][avgm < maxm],
            "MAX_N": [maxdim, maxn][avgn < maxn],
            "MAX_K": [maxdim, maxk][avgk < maxk],
            "FLAGS": flags,
            "ILP64": [0, 1][0 != ilp64],
            "ALPHA": alpha,
            "BETA": beta,
            "WRAP": wrap,
            "MALLOC": malloc,
            "SYNC": [0, 1][0 != sync],
            "JIT": [0, 1][0 != jit],
            "LIBXSMM_OFFLOAD_BUILD": ["", "\n#define LIBXSMM_OFFLOAD_BUILD"][
                0 != offload
            ],
            "MNK_PREPROCESSOR_LIST": "",
        }
        template = Template(open(filename, "r").read())
        if fnmatch.fnmatch(filename, "*.h*"):
            if mnklist:
                first = mnklist[0]
            for mnk in mnklist:
                mnkstr = "_".join(map(str, mnk))
                if mnk != first:
                    substitute["MNK_PREPROCESSOR_LIST"] += "\n"
                if 2 != precision:
                    substitute["MNK_PREPROCESSOR_LIST"] += (
                        "#define LIBXSMM_SMM_" + mnkstr
                    )
                if mnk != first or 0 == precision:
                    substitute["MNK_PREPROCESSOR_LIST"] += "\n"
                if 1 != precision:
                    substitute["MNK_PREPROCESSOR_LIST"] += (
                        "#define LIBXSMM_DMM_" + mnkstr
                    )
            print(template.substitute(substitute))
        else:
            substitute["BLASINT_KIND"] = ["C_INT", "C_LONG_LONG"][0 != ilp64]
            print(template.safe_substitute(substitute))
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
```
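For orientation (not part of this commit), the argument handling above implies an invocation along the following lines. This is a minimal sketch: the template path, output name, and all numeric settings are hypothetical examples, and the library's real build system may pass different values.

```sh
# Hypothetical invocation of the configuration generator. Argument order is
# taken from the script itself: template, ilp64, offload, cacheline,
# precision, prefetch, threshold, sync, jit, flags, alpha, beta, wrap,
# malloc, then an optional list of M_N_K triplets.
python3 scripts/libxsmm_config.py src/template/libxsmm_config.h \
  0 0 64 0 -1 0 1 1 0 1 1 1 0 \
  4_4_4 8_8_8 16_16_16 > libxsmm_config.h
```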
third_party/libxsmm/scripts/libxsmm_dispatch.py — new file (mode 100755):

```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                      #
# Further information: https://github.com/hfp/libxsmm/                       #
# SPDX-License-Identifier: BSD-3-Clause                                      #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import libxsmm_utilities
import sys
import os

if __name__ == "__main__":
    argc = len(sys.argv)
    if 1 < argc:
        arg1_filename = [sys.argv[1], ""]["0" == sys.argv[1]]
        arg1_isfile = os.path.isfile(arg1_filename)
        base = 1
        if arg1_isfile:
            print("#if !defined(_WIN32)")
            print("{ static const char *const build_state =")
            print('# include "../' + os.path.basename(arg1_filename) + '"')
            print(" ;")
            print(" internal_build_state = build_state;")
            print("}")
            print("#endif")
            base = 2
        if (base + 2) < argc:
            precision = int(sys.argv[base + 0])
            threshold = int(sys.argv[base + 1])
            mnklist = libxsmm_utilities.load_mnklist(sys.argv[base + 2:], 0)
            print("/* omit registering code if JIT is enabled"
                  " and if an ISA extension is found")
            print(" * which is beyond the static code"
                  " path used to compile the library")
            print(" */")
            print("#if (0 != LIBXSMM_JIT) && !defined(__MIC__)")
            print("if (LIBXSMM_X86_GENERIC > libxsmm_target_archid "
                  "/* JIT code gen. is not available */")
            print(" /* conditions allows to avoid JIT "
                  "(if static code is good enough) */")
            print(" || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)")
            print(" || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&")
            print(" libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) ==")
            print(" libxsmm_cpuid_vlen32(libxsmm_target_archid)))")
            print("#endif")
            print("{")
            print(" libxsmm_xmmfunction func;")
            for mnk in mnklist:
                mstr, nstr, kstr, mnkstr = (
                    str(mnk[0]),
                    str(mnk[1]),
                    str(mnk[2]),
                    "_".join(map(str, mnk)),
                )
                mnksig = mstr + ", " + nstr + ", " + kstr
                # prefer registering double-precision kernels
                # when approaching an exhausted registry
                if 1 != precision:  # only double-precision
                    print(" func.dmm = (libxsmm_dmmfunction)libxsmm_dmm_"
                          + mnkstr + ";")
                    print(" internal_register_static_code("
                          + "LIBXSMM_GEMM_PRECISION_F64, "
                          + mnksig + ", func, new_registry);")
            for mnk in mnklist:
                mstr, nstr, kstr, mnkstr = (
                    str(mnk[0]),
                    str(mnk[1]),
                    str(mnk[2]),
                    "_".join(map(str, mnk)),
                )
                mnksig = mstr + ", " + nstr + ", " + kstr
                # prefer registering double-precision kernels
                # when approaching an exhausted registry
                if 2 != precision:  # only single-precision
                    print(" func.smm = (libxsmm_smmfunction)libxsmm_smm_"
                          + mnkstr + ";")
                    print(" internal_register_static_code("
                          + "LIBXSMM_GEMM_PRECISION_F32, "
                          + mnksig + ", func, new_registry);")
            print("}")
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
```
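As a hedged sketch (not from this commit), the script can be driven as below; the build-state file name and the redirected output name are hypothetical, while the argument order (build-state file or "0", precision, threshold, M_N_K list) is taken from the script itself.

```sh
# Hypothetical invocation: ".state" is an assumed build-state file name;
# passing "0" instead would skip the build-state preamble.
python3 scripts/libxsmm_dispatch.py .state 0 262144 4_4_4 8_8_8 > libxsmm_dispatch.h
```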
third_party/libxsmm/scripts/libxsmm_interface.py — new file (mode 100755):

```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                      #
# Further information: https://github.com/hfp/libxsmm/                       #
# SPDX-License-Identifier: BSD-3-Clause                                      #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from string import Template
import libxsmm_utilities
import fnmatch
import sys

if __name__ == "__main__":
    argc = len(sys.argv)
    if 1 < argc:
        # required argument(s)
        filename = sys.argv[1]
        # default configuration if no arguments are given
        precision = 0  # all
        ifversion = 1  # interface
        prefetch = -1  # auto
        mnklist = list()
        # optional argument(s)
        if 2 < argc:
            ivalue = int(sys.argv[2])
            ifversion = ivalue >> 2
            precision = ivalue & 3
        if 3 < argc:
            prefetch = int(sys.argv[3])
        if 4 < argc:
            mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[4:], 0))
        template = Template(open(filename, "r").read())
        if fnmatch.fnmatch(filename, "*.h*"):
            optional = [", ...", ""][0 <= prefetch]
            substitute = {"MNK_INTERFACE_LIST": ""}
            for mnk in mnklist:
                mnkstr = "_".join(map(str, mnk))
                if 2 != precision:
                    pfsig = [
                        optional + ");",
                        ",\n  "
                        "const float* pa, "
                        "const float* pb, "
                        "const float* pc);",
                    ][0 < prefetch]
                    substitute["MNK_INTERFACE_LIST"] += (
                        "\nLIBXSMM_API void libxsmm_smm_" + mnkstr
                        + "(const float* a, const float* b, float* c" + pfsig
                    )
                if 1 != precision:
                    pfsig = [
                        optional + ");",
                        ",\n  "
                        "const double* pa, "
                        "const double* pb, "
                        "const double* pc);",
                    ][0 < prefetch]
                    substitute["MNK_INTERFACE_LIST"] += (
                        "\nLIBXSMM_API void libxsmm_dmm_" + mnkstr
                        + "(const double* a, const double* b, double* c" + pfsig
                    )
                if 0 == precision:
                    substitute["MNK_INTERFACE_LIST"] += "\n"
            if mnklist and 0 != precision:
                substitute["MNK_INTERFACE_LIST"] += "\n"
            print(template.substitute(substitute))
        else:  # Fortran interface
            if 1 > ifversion and 0 != ifversion:
                raise ValueError("Fortran interface level is inconsistent!")
            # Fortran's OPTIONAL allows to always generate an interface
            # with prefetch signature (more flexible usage)
            if 0 == prefetch:
                prefetch = -1
            version, branch, realversion = libxsmm_utilities.version_branch(16)
            major, minor, update, patch = libxsmm_utilities.version_numbers(
                version
            )
            substitute = {
                "VERSION": realversion,
                "BRANCH": branch,
                "MAJOR": major,
                "MINOR": minor,
                "UPDATE": update,
                "PATCH": patch,
                "MNK_INTERFACE_LIST": "",
                "CONTIGUOUS": ["", ", CONTIGUOUS"][1 < ifversion],
            }
            if mnklist:
                substitute["MNK_INTERFACE_LIST"] += "\n"
                for mnk in mnklist:
                    mnkstr = "_".join(map(str, mnk))
                    if 0 == precision:
                        substitute["MNK_INTERFACE_LIST"] += (
                            "\n"
                            "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_"
                            + mnkstr + ", libxsmm_dmm_" + mnkstr
                        )
                    elif 2 != precision:
                        substitute["MNK_INTERFACE_LIST"] += (
                            "\n"
                            "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_"
                            + mnkstr
                        )
                    elif 1 != precision:
                        substitute["MNK_INTERFACE_LIST"] += (
                            "\n"
                            "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_dmm_"
                            + mnkstr
                        )
                substitute["MNK_INTERFACE_LIST"] += "\nINTERFACE"
                optional = [", OPTIONAL", ""][0 < prefetch]
                bindc = ["", "BIND(C)"][0 < prefetch]
                for mnk in mnklist:
                    mnkstr = "_".join(map(str, mnk))
                    if 2 != precision:
                        pfsiga = [
                            ") BIND(C)\n",
                            "," + "&".rjust(26 - len(mnkstr))
                            + "\n     & pa, pb, pc) " + bindc + "\n",
                        ][0 != prefetch]
                        pfsigb = [
                            "",
                            " REAL(C_FLOAT), "
                            "INTENT(IN)" + optional + " :: "
                            "pa(*), "
                            "pb(*), "
                            "pc(*)\n",
                        ][0 != prefetch]
                        substitute["MNK_INTERFACE_LIST"] += (
                            "\n"
                            "PURE SUBROUTINE libxsmm_smm_" + mnkstr
                            + "(a, b, c" + pfsiga
                            + " IMPORT :: C_FLOAT\n"
                            " REAL(C_FLOAT), "
                            "INTENT(IN) :: a(*), b(*)\n"
                            " REAL(C_FLOAT), "
                            "INTENT(INOUT) :: c(*)\n"
                            + pfsigb + " END SUBROUTINE"
                        )
                    if 1 != precision:
                        pfsiga = [
                            ") BIND(C)\n",
                            "," + "&".rjust(26 - len(mnkstr))
                            + "\n     & pa, pb, pc) " + bindc + "\n",
                        ][0 != prefetch]
                        pfsigb = [
                            "",
                            " REAL(C_DOUBLE), "
                            "INTENT(IN)" + optional + " :: "
                            "pa(*), "
                            "pb(*), "
                            "pc(*)\n",
                        ][0 != prefetch]
                        substitute["MNK_INTERFACE_LIST"] += (
                            "\n"
                            "PURE SUBROUTINE libxsmm_dmm_" + mnkstr
                            + "(a, b, c" + pfsiga
                            + " IMPORT :: C_DOUBLE\n"
                            " REAL(C_DOUBLE), "
                            "INTENT(IN) :: a(*), b(*)\n"
                            " REAL(C_DOUBLE), "
                            "INTENT(INOUT) :: c(*)\n"
                            + pfsigb + " END SUBROUTINE"
                        )
                substitute["MNK_INTERFACE_LIST"] += "\nEND INTERFACE"
            print(template.safe_substitute(substitute))
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
```
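A minimal usage sketch (not from this commit): the second argument packs the interface version into the upper bits and the precision into the low two bits, as the script's `ivalue >> 2` / `ivalue & 3` handling shows. The template path below is an assumption.

```sh
# Hypothetical: generate the C header variant; ivalue=4 decodes to
# ifversion=1 and precision=0 (both precisions), prefetch=-1 (auto).
python3 scripts/libxsmm_interface.py src/template/libxsmm.h 4 -1 4_4_4 8_8_8 > libxsmm.h
```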
third_party/libxsmm/scripts/libxsmm_source.sh — new file (mode 100755):

```sh
#!/usr/bin/env sh
SRCDIR=../src
GREP=$(command -v grep)

if [ "" = "${GREP}" ]; then
  >&2 echo "Error: missing prerequisites!"
  exit 1
fi

cat << EOM
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                       *
* SPDX-License-Identifier: BSD-3-Clause                                      *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_SOURCE_H
#define LIBXSMM_SOURCE_H

#if defined(LIBXSMM_MACROS_H)
# error Please do not include any LIBXSMM header other than libxsmm_source.h!
#endif
#if defined(LIBXSMM_BUILD)
# error LIBXSMM_BUILD cannot be defined for the header-only LIBXSMM!
#endif

/**
 * This header is intentionally called "libxsmm_source.h" since the followings block
 * includes *internal* files, and thereby exposes LIBXSMM's implementation.
 * The so-called "header-only" usage model gives up the clearly defined binary interface
 * (including support for hot-fixes after deployment), and requires to rebuild client
 * code for every (internal) change of LIBXSMM. Please make sure to only rely on the
 * public interface as the internal implementation may change without notice.
 */
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
EOM

HERE=$(cd "$(dirname "$0")" && pwd -P)
if [ "" = "$1" ]; then
  DSTDIR=${SRCDIR}
else
  DSTDIR=$1
fi

# determine order of filenames in directory list
export LC_ALL=C

# good-enough pattern to match a main function, and to exclude this translation unit
for FILE in $(cd "${HERE}/${SRCDIR}" && ${GREP} -L "main[[:space:]]*(.*)" ./*.c); do
  BASENAME=$(basename "${FILE}")
  echo "#include \"${DSTDIR}/${BASENAME}\""
done

cat << EOM
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#endif /*LIBXSMM_SOURCE_H*/
EOM
```
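A hedged example of how this generator might be run (the output path is an assumption; the script itself only writes to stdout and takes an optional destination-directory argument):

```sh
# Hypothetical: emit the header-only libxsmm_source.h, which #includes every
# .c translation unit under ../src that does not define a main() function.
cd third_party/libxsmm/scripts && ./libxsmm_source.sh > ../include/libxsmm_source.h
```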
third_party/libxsmm/scripts/libxsmm_specialized.py — new file (mode 100755):

```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                      #
# Further information: https://github.com/hfp/libxsmm/                       #
# SPDX-License-Identifier: BSD-3-Clause                                      #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import sys

if __name__ == "__main__":
    argc = len(sys.argv)
    if 6 == argc:
        precision = int(sys.argv[1])
        m, n, k = int(sys.argv[2]), int(sys.argv[3]), int(sys.argv[4])
        prefetch = int(sys.argv[5])
        mnkstr = str(m) + "_" + str(n) + "_" + str(k)
        optional = ["", ", ..."][0 > prefetch]
        signature = ["a, b, c", "a, b, c, pa, pb, pc"][0 < prefetch]
        if 2 != precision:
            pfsig = [
                optional + ")",
                "\n  "
                ", const float* pa"
                ", const float* pb"
                ", const float* pc)",
            ][0 < prefetch]
            print()  # blank lines separate the generated functions
            print()
            print("LIBXSMM_API void libxsmm_smm_" + mnkstr
                  + "(const float* a, const float* b, float* c" + pfsig)
            print("{")
            print("#if defined(__AVX512F__) && "
                  "defined(LIBXSMM_GENTARGET_skx_sp) && \\")
            print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))")
            print(" libxsmm_smm_" + mnkstr + "_skx(" + signature + ");")
            print("#elif defined(__AVX512F__) && "
                  "defined(LIBXSMM_GENTARGET_knl_sp)")
            print(" libxsmm_smm_" + mnkstr + "_knl(" + signature + ");")
            print("#elif defined(__AVX2__) && "
                  "defined(LIBXSMM_GENTARGET_hsw_sp)")
            print(" libxsmm_smm_" + mnkstr + "_hsw(" + signature + ");")
            print("#elif defined(__AVX__) && "
                  "defined(LIBXSMM_GENTARGET_snb_sp)")
            print(" libxsmm_smm_" + mnkstr + "_snb(" + signature + ");")
            print("#elif defined(__SSE3__) && "
                  "defined(LIBXSMM_GENTARGET_wsm_sp)")
            print(" libxsmm_smm_" + mnkstr + "_wsm(" + signature + ");")
            print("#else")
            print(" const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & "
                  "LIBXSMM_FLAGS) ? 'N' : 'T');")
            print(" const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & "
                  "LIBXSMM_FLAGS) ? 'N' : 'T');")
            print(" const float alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;")
            print(" const libxsmm_blasint "
                  "m = " + str(m) + ", "
                  "n = " + str(n) + ", "
                  "k = " + str(k) + ";")
            if 0 < prefetch:
                print(" LIBXSMM_UNUSED(pa);"
                      " LIBXSMM_UNUSED(pb);"
                      " LIBXSMM_UNUSED(pc);")
            print(" LIBXSMM_INLINE_XGEMM(float, float, &transa, &transb,"
                  " &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);")
            print("#endif")
            print("}")
            print()
            print()
            print("LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_smm_" + mnkstr
                  + ")(const float* a, const float* b, float* c" + pfsig + ";")
            print("LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_smm_" + mnkstr
                  + ")(const float* a, const float* b, float* c" + pfsig)
            print("{")
            print(" libxsmm_smm_" + mnkstr + "(" + signature + ");")
            print("}")
        if 1 != precision:
            pfsig = [
                optional + ")",
                "\n  "
                ", const double* pa"
                ", const double* pb"
                ", const double* pc)",
            ][0 < prefetch]
            print()
            print()
            print("LIBXSMM_API void libxsmm_dmm_" + mnkstr
                  + "(const double* a, const double* b, double* c" + pfsig)
            print("{")
            print("#if defined(__AVX512F__) && "
                  "defined(LIBXSMM_GENTARGET_skx_dp) && \\")
            print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))")
            print(" libxsmm_dmm_" + mnkstr + "_skx(" + signature + ");")
            print("#elif defined(__AVX512F__) && "
                  "defined(LIBXSMM_GENTARGET_knl_dp)")
            print(" libxsmm_dmm_" + mnkstr + "_knl(" + signature + ");")
            print("#elif defined(__AVX2__) && "
                  "defined(LIBXSMM_GENTARGET_hsw_dp)")
            print(" libxsmm_dmm_" + mnkstr + "_hsw(" + signature + ");")
            print("#elif defined(__AVX__) && "
                  "defined(LIBXSMM_GENTARGET_snb_dp)")
            print(" libxsmm_dmm_" + mnkstr + "_snb(" + signature + ");")
            print("#elif defined(__SSE3__) && "
                  "defined(LIBXSMM_GENTARGET_wsm_dp)")
            print(" libxsmm_dmm_" + mnkstr + "_wsm(" + signature + ");")
            print("#else")
            print(" const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & "
                  "LIBXSMM_FLAGS) ? 'N' : 'T');")
            print(" const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & "
                  "LIBXSMM_FLAGS) ? 'N' : 'T');")
            print(" const double alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;")
            print(" const libxsmm_blasint "
                  "m = " + str(m) + ", "
                  "n = " + str(n) + ", "
                  "k = " + str(k) + ";")
            if 0 < prefetch:
                print(" LIBXSMM_UNUSED(pa);"
                      " LIBXSMM_UNUSED(pb);"
                      " LIBXSMM_UNUSED(pc);")
            print(" LIBXSMM_INLINE_XGEMM(double, double, &transa, &transb,"
                  " &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);")
            print("#endif")
            print("}")
            print()
            print()
            print("LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dmm_" + mnkstr
                  + ")(const double* a, const double* b, double* c" + pfsig + ";")
            print("LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dmm_" + mnkstr
                  + ")(const double* a, const double* b, double* c" + pfsig)
            print("{")
            print(" libxsmm_dmm_" + mnkstr + "(" + signature + ");")
            print("}")
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
```
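A minimal sketch of how this kernel-wrapper generator might be driven (the exact five-argument form — precision, M, N, K, prefetch — comes from the script; the output redirection target is an assumption):

```sh
# Hypothetical: emit both single- and double-precision (precision=0)
# specialized 8x8x8 wrappers without prefetch signatures (prefetch=-1).
python3 scripts/libxsmm_specialized.py 0 8 8 8 -1 >> libxsmm_specialized.c
```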
third_party/libxsmm/scripts/libxsmm_utilities.py — new file (mode 100755):

```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                      #
# Further information: https://github.com/hfp/libxsmm/                       #
# SPDX-License-Identifier: BSD-3-Clause                                      #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import itertools
import operator
import inspect
import sys
import os

try:
    from functools import reduce
except ImportError:
    pass


def upper_list(lists, level):
    nlist = len(lists)
    upper = [level, level + nlist][1 > level] - 1
    above = lists[upper]
    if above:
        return above
    elif -nlist <= level:
        return upper_list(lists, level - 1)
    else:
        return []


# https://docs.python.org/3/library/itertools.html#itertools.product
def itertools_product(*args):
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
    pools = [tuple(pool) for pool in args]
    result = [[]]
    for pool in pools:
        result = [x + [y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)


def load_mnklist(argv, threshold, inputformat=0, resultset=None):
    if resultset is None:
        resultset = set()
    if 0 == inputformat:  # indexes format
        resultset = set(
            map(lambda mnk: tuple(map(int, mnk.split("_"))), argv)
        )
    elif -1 == inputformat:  # new input format
        groups = map(
            lambda group: [int(i) for i in group.split()],
            " ".join(argv[0:]).split(","),
        )
        resultset = set(
            itertools.chain(
                *[list(itertools_product(*(i, i, i))) for i in groups]
            )
        )
    elif -2 == inputformat:  # legacy format
        mlist = list(
            map(
                int,
                map(
                    lambda s: str(s).replace(",", " ").strip(),
                    argv[2 : 2 + int(argv[0])],
                ),
            )
        )
        nlist = list(
            map(
                int,
                map(
                    lambda s: str(s).replace(",", " ").strip(),
                    argv[2 + int(argv[0]) : 2 + int(argv[0]) + int(argv[1])],
                ),
            )
        )
        klist = list(
            map(
                int,
                map(
                    lambda s: str(s).replace(",", " ").strip(),
                    argv[2 + int(argv[0]) + int(argv[1]) :],
                ),
            )
        )
        mnk = [mlist, nlist, klist]
        top = [
            [mlist, upper_list(mnk, 0)][0 == len(mlist)],
            [nlist, upper_list(mnk, 1)][0 == len(nlist)],
            [klist, upper_list(mnk, 2)][0 == len(klist)],
        ]
        for m in top[0]:
            for n in top[1]:
                if not nlist:
                    n = m
                for k in top[2]:
                    if not klist:
                        k = n
                    if not mlist:
                        m = k
                    resultset.add((m, n, k))
    else:
        sys.tracebacklimit = 0
        raise ValueError("load_mnklist: unexpected input format!")
    if 0 != threshold:  # threshold requested
        return set(
            filter(
                lambda mnk: (0 < mnk[0])
                and (0 < mnk[1])
                and (0 < mnk[2])
                and (threshold >= (mnk[0] * mnk[1] * mnk[2])),
                resultset,
            )
        )
    else:
        return set(
            filter(
                lambda mnk: (0 < mnk[0]) and (0 < mnk[1]) and (0 < mnk[2]),
                resultset,
            )
        )


def max_mnk(mnklist, init=0, index=None):
    if index is not None and 0 <= index and index < 3:
        mapped = map(lambda mnk: mnk[index], mnklist)
    else:
        mapped = map(lambda mnk: mnk[0] * mnk[1] * mnk[2], mnklist)
    return reduce(max, mapped, init)


def median(list_of_numbers, fallback=None, average=True):
    size = len(list_of_numbers)
    if 0 < size:
        # TODO: use nth element
        list_of_numbers.sort()
        size2 = int(size / 2)
        if average and 0 == (size - size2 * 2):
            medval = int(
                0.5 * (list_of_numbers[size2 - 1] + list_of_numbers[size2])
                + 0.5
            )
        else:
            medval = list_of_numbers[size2]
        if fallback is not None:
            result = min(medval, fallback)
        else:
            result = medval
    elif fallback is not None:
        result = fallback
    else:
        sys.tracebacklimit = 0
        raise ValueError("median: empty list!")
    return result


def is_pot(num):
    return 0 <= num or 0 == (num & (num - 1))


def sanitize_alignment(alignment):
    if 0 >= alignment:
        alignment = [1, 64][0 != alignment]
    elif not is_pot(alignment):
        sys.tracebacklimit = 0
        raise ValueError(
            "sanitize_alignment: alignment must be a Power of Two (POT)!"
        )
    return alignment


def align_value(n, typesize, alignment):
    if 0 < typesize and 0 < alignment:
        return (
            ((n * typesize + alignment - 1) / alignment) * alignment
        ) / typesize
    else:
        sys.tracebacklimit = 0
        raise ValueError("align_value: invalid input!")


def version_branch_from_file(version_filepath):
    version_file = open(version_filepath, "r")
    version, branch, sep = "1.0", "", "-"
    try:
        version_list, n = version_file.read().replace("\n", "").split(sep), 0
        for word in version_list:
            if not reduce(
                operator.and_,
                (subword.isdigit() for subword in word.split(".")),
                True,
            ):
                branch += [sep + word, word][0 == n]
                n += 1
            else:
                break
        version = sep.join(version_list[n:])
    finally:
        version_file.close()
    return (version, branch)


def version_numbers(version, branch=None):
    version_list = version.split("-")
    if not version_list[0][0].isdigit():
        vbranch = version_list[0]
    else:
        vbranch = "master"
    if branch is None or vbranch == branch:
        minor = update = patch = 0
        major = 1
        n = len(version_list)
        if 1 < n:
            patch_list = version_list[n - 1]
            if 1 == len(patch_list.split(".")):
                version_list = version_list[n - 2].split(".")
                if version_list != [vbranch]:
                    patch = int(patch_list)
                else:
                    major = int(patch_list)
            else:
                version_list = patch_list.split(".")
        else:
            version_list = version.split(".")
        n = len(version_list)
        try:
            if 0 < n:
                major = int(version_list[0])
            if 1 < n:
                minor = int(version_list[1])
            if 2 < n:
                update = int(version_list[2])
        except ValueError:
            # if 1 == n: major = 0
            pass
    else:
        major = minor = update = patch = -1
    return major, minor, update, patch


def version_branch(max_strlen=-1):
    version_filename = "version.txt"
    filepath_default = os.path.realpath(
        os.path.join(
            os.path.dirname(inspect.getfile(inspect.currentframe())),
            "..",
            version_filename,
        )
    )
    filepath_local = os.path.realpath(version_filename)  # local version file
    realversion, branch = version_branch_from_file(filepath_default)
    version = realversion
    out_of_tree = filepath_default != filepath_local
    if out_of_tree and os.path.isfile(filepath_local):
        local, ignored = version_branch_from_file(filepath_local)
        if version_numbers(realversion) < version_numbers(local):
            version = local
    if 0 < max_strlen:
        start = int(max_strlen / 3)
        cut = max(
            branch.rfind("-", start, max_strlen),
            branch.rfind("_", start, max_strlen),
            branch.rfind(".", start, max_strlen),
        )
        if start < cut:
            branch = branch[0:cut]
        else:
            branch = branch[0:max_strlen]
    return (version, branch, realversion)


if __name__ == "__main__":
    argc = len(sys.argv)
    if 1 < argc:
        arg1 = int(sys.argv[1])
    else:
        arg1 = 0
    if -1 == arg1:
        if 5 < argc:
            # threshold = int(sys.argv[2])
            mnk_size = int(sys.argv[3])
            dims = load_mnklist(sys.argv[4 : 4 + mnk_size], 0, -1)
            dims = load_mnklist(sys.argv[4 + mnk_size :], 0, -2, dims)
            mnklist = map(lambda mnk: "_".join(map(str, mnk)), sorted(dims))
            print(" ".join(mnklist))
        elif 3 == argc:
            major, minor, update, patch = version_numbers(
                sys.argv[2], "release"
            )
            print(["0", "1"][0 == patch])
    elif 0 <= arg1:
        if 0 == arg1 and 3 == argc:
            major, minor, update, patch = version_numbers(sys.argv[2])
            print(major)  # soname version
        else:
            version, branch, realversion = version_branch()
            major, minor, update, patch = version_numbers(version)
            if 1 == arg1:
                print(major)
            elif 2 == arg1:
                print(minor)
            elif 3 == arg1:
                print(update)
            elif 4 == arg1:
                print(patch)
            elif "" != branch:
                print("{}-{}".format(branch, realversion))
            else:
                print(realversion)
    else:
        sys.tracebacklimit = 0
        raise ValueError(
            "{}: wrong ({}) number of arguments ('{}') given!".format(
                sys.argv[0], argc - 1, " ".join(sys.argv[1:])
            )
        )
```
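Besides being imported by the other generators, the module doubles as a CLI via its `__main__` block. The following invocations are grounded directly in that block (the relative path is the only assumption):

```sh
python3 scripts/libxsmm_utilities.py 1        # print the major version number
python3 scripts/libxsmm_utilities.py 4        # print the patch level
python3 scripts/libxsmm_utilities.py 0 1.16   # soname major for a given version string
```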
third_party/libxsmm/scripts/libxsmm_version.sh — new file (mode 100755):

```sh
#!/usr/bin/env sh
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                      #
# Further information: https://github.com/hfp/libxsmm/                       #
# SPDX-License-Identifier: BSD-3-Clause                                      #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
GIT=$(command -v git)
SHIFT=0

if [ "$1" ]; then
  SHIFT=$1
fi

NAME=$(${GIT} rev-parse --abbrev-ref HEAD 2>/dev/null)
MAIN=$(${GIT} describe --tags --match "[0-9]*" --abbrev=0 2>/dev/null)

if [ "${MAIN}" ]; then
  VERSION="${NAME}-${MAIN}"
  REVC=$(${GIT} rev-list --count --no-merges "${MAIN}"..HEAD 2>/dev/null)
else
  VERSION=${NAME}
  REVC=$(${GIT} rev-list --count --no-merges HEAD 2>/dev/null)
fi

echo "${VERSION}-$((REVC+SHIFT))"
```
third_party/libxsmm/src/libxsmm_cpuid_arm.c — new file (mode 100644):

```c
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                       *
* SPDX-License-Identifier: BSD-3-Clause                                      *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm_cpuid.h>
#include <libxsmm_generator.h>
#include <libxsmm_memory.h>
#include <libxsmm_sync.h>

#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <signal.h>
#include <setjmp.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif

#if defined(_MSC_VER)
# define LIBXSMM_CPUID_ARM_ENC16(OP0, OP1, CRN, CRM, OP2) ( \
    (((OP0) & 1) << 14) | \
    (((OP1) & 7) << 11) | \
    (((CRN) & 15) << 7) | \
    (((CRM) & 15) << 3) | \
    (((OP2) & 7) << 0))
# define ID_AA64ISAR1_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0110, 0b001)
# define ID_AA64PFR0_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0100, 0b000)
# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) RESULT = _ReadStatusReg(ID)
#else
# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) __asm__ __volatile__( \
    "mrs %0," LIBXSMM_STRINGIFY(ID) : "=r"(RESULT))
#endif

#if defined(LIBXSMM_PLATFORM_AARCH64)

LIBXSMM_APIVAR_DEFINE(jmp_buf internal_cpuid_arm_jmp_buf);

LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int /*signum*/);
LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int signum)
{
  void (*const handler)(int) = signal(signum, internal_cpuid_arm_sigill);
  LIBXSMM_ASSERT(SIGILL == signum);
  if (SIG_ERR != handler) longjmp(internal_cpuid_arm_jmp_buf, 1);
}

#endif

LIBXSMM_API int libxsmm_cpuid_arm(libxsmm_cpuid_info* info)
{
  static int result = LIBXSMM_TARGET_ARCH_UNKNOWN;
#if defined(LIBXSMM_PLATFORM_AARCH64)
# if defined(__APPLE__) && defined(__arm64__)
  result = LIBXSMM_AARCH64_V81;
# else
  if (LIBXSMM_TARGET_ARCH_UNKNOWN == result) { /* avoid redetecting features */
    void (*const handler)(int) = signal(SIGILL, internal_cpuid_arm_sigill);
    result = LIBXSMM_AARCH64_V81;
    if (SIG_ERR != handler) {
      uint64_t capability; /* 64-bit value */
      if (0 == setjmp(internal_cpuid_arm_jmp_buf)) {
        LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64ISAR1_EL1);
        if (0xF & capability) { /* DPB */
          result = LIBXSMM_AARCH64_V82;
          if (0 == setjmp(internal_cpuid_arm_jmp_buf)) {
            LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64PFR0_EL1);
            if (0xF & (capability >> 32)) { /* SVE */
              result = LIBXSMM_AARCH64_A64FX;
            }
          }
        }
      }
      /* restore original state */
      signal(SIGILL, handler);
    }
    if (NULL != info) LIBXSMM_MEMZERO127(info);
  }
# endif
#else
# if !defined(NDEBUG)
  static int error_once = 0;
  if (0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_arm called on non-ARM platform!\n");
  }
# endif
  if (NULL != info) LIBXSMM_MEMZERO127(info);
#endif
  return result;
}
```
third_party/libxsmm/src/libxsmm_cpuid_x86.c — new file (mode 100644):

```c
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                       *
* SPDX-License-Identifier: BSD-3-Clause                                      *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm_generator.h>
#include <libxsmm_memory.h>
#include <libxsmm_sync.h>
#if !defined(_WIN32)
# include <sys/mman.h>
#endif

#if defined(LIBXSMM_PLATFORM_X86)
/* XGETBV: receive results (EAX, EDX) for eXtended Control Register (XCR). */
/* CPUID, receive results (EAX, EBX, ECX, EDX) for requested FUNCTION/SUBFN. */
# if defined(_MSC_VER) /*defined(_WIN32) && !defined(__GNUC__)*/
#   define LIBXSMM_XGETBV(XCR, EAX, EDX) { \
      unsigned long long libxsmm_xgetbv_ = _xgetbv(XCR); \
      EAX = (int)libxsmm_xgetbv_; \
      EDX = (int)(libxsmm_xgetbv_ >> 32); \
    }
#   define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) { \
      int libxsmm_cpuid_x86_[/*4*/] = { 0, 0, 0, 0 }; \
      __cpuidex(libxsmm_cpuid_x86_, FUNCTION, SUBFN); \
      EAX = (unsigned int)libxsmm_cpuid_x86_[0]; \
      EBX = (unsigned int)libxsmm_cpuid_x86_[1]; \
      ECX = (unsigned int)libxsmm_cpuid_x86_[2]; \
      EDX = (unsigned int)libxsmm_cpuid_x86_[3]; \
    }
# elif defined(__GNUC__) || !defined(_CRAYC)
#   if (64 > (LIBXSMM_BITS))
      LIBXSMM_EXTERN LIBXSMM_RETARGETABLE int __get_cpuid( /* prototype */
        unsigned int, unsigned int*, unsigned int*, unsigned int*, unsigned int*);
#     define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0xFFFFFFFF
#     define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \
        EAX = (EBX) = (EDX) = 0; ECX = (SUBFN); \
        __get_cpuid(FUNCTION, &(EAX), &(EBX), &(ECX), &(EDX))
#   else /* 64-bit */
#     define LIBXSMM_XGETBV(XCR, EAX, EDX) __asm__ __volatile__( \
        ".byte 0x0f, 0x01, 0xd0" /*xgetbv*/ : "=a"(EAX), "=d"(EDX) : "c"(XCR) \
      )
#     define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \
        __asm__ __volatile__ (".byte 0x0f, 0xa2" /*cpuid*/ \
        : "=a"(EAX), "=b"(EBX), "=c"(ECX), "=d"(EDX) \
        : "a"(FUNCTION), "b"(0), "c"(SUBFN), "d"(0) \
      )
#   endif
# else /* legacy Cray Compiler */
#   define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0
#   define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) EAX = (EBX) = (ECX) = (EDX) = 0
# endif
#endif

#define LIBXSMM_CPUID_CHECK(VALUE, CHECK) ((CHECK) == ((CHECK) & (VALUE)))

LIBXSMM_API int libxsmm_cpuid_x86(libxsmm_cpuid_info* info)
{
  static int result = LIBXSMM_TARGET_ARCH_UNKNOWN;
#if defined(LIBXSMM_PLATFORM_X86)
  unsigned int eax, ebx, ecx, edx;
  LIBXSMM_CPUID_X86(0, 0/*ecx*/, eax, ebx, ecx, edx);
  if (1 <= eax) { /* CPUID max. leaf */
    /* avoid redetecting features but redetect on request (info given) */
    if (LIBXSMM_TARGET_ARCH_UNKNOWN == result || NULL != info) {
      int feature_cpu = LIBXSMM_X86_GENERIC, feature_os = LIBXSMM_X86_GENERIC, has_context = 0;
      unsigned int maxleaf = eax;
# if defined(__linux__)
      if (0 == libxsmm_se && LIBXSMM_TARGET_ARCH_UNKNOWN == result) {
        FILE* const selinux = fopen("/sys/fs/selinux/enforce", "rb");
        if (NULL != selinux) {
          if (1 == fread(&libxsmm_se, 1/*sizeof(char)*/, 1/*count*/, selinux)) {
            libxsmm_se = ('0' != libxsmm_se ? 1 : 0);
          }
          else { /* conservative assumption in case of read-error */
            libxsmm_se = 1;
          }
          fclose(selinux);
        }
      }
# elif defined(MAP_JIT)
      libxsmm_se = 1;
# endif
      LIBXSMM_CPUID_X86(1, 0/*ecx*/, eax, ebx, ecx, edx);
      if (LIBXSMM_CPUID_CHECK(ecx, 0x00000001)) { /* SSE3(0x00000001) */
        if (LIBXSMM_CPUID_CHECK(ecx, 0x00100000)) { /* SSE42(0x00100000) */
          if (LIBXSMM_CPUID_CHECK(ecx, 0x10000000)) { /* AVX(0x10000000) */
            if (LIBXSMM_CPUID_CHECK(ecx, 0x00001000)) { /* FMA(0x00001000) */
              unsigned int ecx2;
              LIBXSMM_CPUID_X86(7, 0/*ecx*/, eax, ebx, ecx2, edx);
              /* AVX512F(0x00010000), AVX512CD(0x10000000) */
              if (LIBXSMM_CPUID_CHECK(ebx, 0x10010000)) { /* Common */
                /* AVX512DQ(0x00020000), AVX512BW(0x40000000), AVX512VL(0x80000000) */
                if (LIBXSMM_CPUID_CHECK(ebx, 0xC0020000)) { /* AVX512-Core */
                  if (LIBXSMM_CPUID_CHECK(ecx2, 0x00000800)) { /* VNNI */
                    unsigned int edx2; /* we need to save edx for AMX check */
# if 0 /* no check required yet */
                    unsigned int ecx3;
                    LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx3, edx);
# else
                    LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx2, edx2);
# endif
                    if (LIBXSMM_CPUID_CHECK(eax, 0x00000020)) { /* BF16 */
                      feature_cpu = LIBXSMM_X86_AVX512_CPX;
                      if (LIBXSMM_CPUID_CHECK(edx, 0x03400000)) { /* AMX-TILE, AMX-INT8, AMX-BF16 */
                        feature_cpu = LIBXSMM_X86_AVX512_SPR;
                      }
                    }
                    else feature_cpu = LIBXSMM_X86_AVX512_CLX; /* CLX */
                  }
                  else feature_cpu = LIBXSMM_X86_AVX512_CORE; /* SKX */
                }
                /* AVX512PF(0x04000000), AVX512ER(0x08000000) */
                else if (LIBXSMM_CPUID_CHECK(ebx, 0x0C000000)) { /* AVX512-MIC */
                  if (LIBXSMM_CPUID_CHECK(edx, 0x0000000C)) { /* KNM */
                    feature_cpu = LIBXSMM_X86_AVX512_KNM;
                  }
                  else feature_cpu = LIBXSMM_X86_AVX512_MIC; /* KNL */
                }
                else feature_cpu = LIBXSMM_X86_AVX512; /* AVX512-Common */
              }
              else feature_cpu = LIBXSMM_X86_AVX2;
            }
            else feature_cpu = LIBXSMM_X86_AVX;
          }
          else feature_cpu = LIBXSMM_X86_SSE42;
        }
        else feature_cpu = LIBXSMM_X86_SSE3;
      }
# if !defined(LIBXSMM_INTRINSICS_DEBUG)
      LIBXSMM_ASSERT_MSG(LIBXSMM_STATIC_TARGET_ARCH <= LIBXSMM_MAX(LIBXSMM_X86_GENERIC, feature_cpu),
        "missed detecting ISA extensions");
      /* coverity[dead_error_line] */
      if (LIBXSMM_STATIC_TARGET_ARCH > feature_cpu) feature_cpu = LIBXSMM_STATIC_TARGET_ARCH;
# endif
      /* XSAVE/XGETBV(0x04000000), OSXSAVE(0x08000000) */
      if (LIBXSMM_CPUID_CHECK(ecx, 0x0C000000)) { /* OS SSE support */
        feature_os = LIBXSMM_MIN(LIBXSMM_X86_SSE42, feature_cpu);
        if (LIBXSMM_X86_AVX <= feature_cpu) {
          LIBXSMM_XGETBV(0, eax, edx);
          if (LIBXSMM_CPUID_CHECK(eax, 0x00000006)) { /* OS XSAVE 256-bit */
            feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX2, feature_cpu);
            if (LIBXSMM_CPUID_CHECK(eax, 0x000000E0)) { /* OS XSAVE 512-bit */
              feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX512_CPX, feature_cpu);
              if (LIBXSMM_X86_AVX512_SPR <= feature_cpu && 7 <= maxleaf
                && LIBXSMM_CPUID_CHECK(eax, 0x00060000)) /* OS XSAVE 512-bit */
              {
                feature_os = feature_cpu; /* unlimited AMX */
              }
            }
          }
        }
      }
      else if (LIBXSMM_X86_GENERIC <= feature_cpu) {
        /* assume FXSAVE, which should be fine
         * 16 years after the first x86_64 OS
         */
        feature_os = LIBXSMM_X86_SSE42;
      }
      else feature_os = LIBXSMM_TARGET_ARCH_GENERIC;
      has_context = (LIBXSMM_STATIC_TARGET_ARCH >= feature_cpu || feature_os >= feature_cpu) ? 1 : 0;
      if (LIBXSMM_TARGET_ARCH_UNKNOWN == result && 0 != libxsmm_verbosity) { /* library code is expected to be mute */
# if !defined(LIBXSMM_TARGET_ARCH)
        const int target_vlen32 = libxsmm_cpuid_vlen32(feature_cpu);
        const char* const compiler_support = (libxsmm_cpuid_vlen32(LIBXSMM_MAX_STATIC_TARGET_ARCH) < target_vlen32
          ? "" : (((2 <= libxsmm_verbosity || 0 > libxsmm_verbosity)
            && LIBXSMM_MAX_STATIC_TARGET_ARCH < feature_cpu) ? "highly " : NULL));
        if (NULL != compiler_support) {
          const char* const name = libxsmm_cpuid_name( /* exclude MIC when running on Core processors */
            (((LIBXSMM_X86_AVX512_MIC == LIBXSMM_MAX_STATIC_TARGET_ARCH)
              || (LIBXSMM_X86_AVX512_KNM == LIBXSMM_MAX_STATIC_TARGET_ARCH))
              && (LIBXSMM_X86_AVX512_CORE <= feature_cpu))
            ? LIBXSMM_X86_AVX2 : LIBXSMM_MAX_STATIC_TARGET_ARCH);
          fprintf(stderr, "LIBXSMM WARNING: %soptimized non-JIT code paths are limited to \"%s\"!\n",
            compiler_support, name);
        }
# endif
# if !defined(NDEBUG) && defined(__OPTIMIZE__)
        fprintf(stderr, "LIBXSMM WARNING: library is optimized without -DNDEBUG and contains debug code!\n");
# endif
# if !defined(__APPLE__) || !defined(__MACH__) /* permitted features */
        if (0 == has_context) {
          fprintf(stderr, "LIBXSMM WARNING: detected CPU features are not permitted by the OS!\n");
          if (0 == libxsmm_se) {
            fprintf(stderr, "LIBXSMM WARNING: downgraded code generation to supported features!\n");
          }
        }
# endif
      }
      /* macOS is faulting AVX-512 (on-demand larger state) */
      result = feature_cpu;
# if !defined(__APPLE__) || !defined(__MACH__)
#   if 0 /* opportunistic */
      if (0 == libxsmm_se)
#   endif
      { /* only permitted features */
        result = LIBXSMM_MIN(feature_cpu, feature_os);
      }
# endif
      if (NULL != info) {
        LIBXSMM_CPUID_X86(0x80000007, 0/*ecx*/, eax, ebx, ecx, edx);
        info->constant_tsc = LIBXSMM_CPUID_CHECK(edx, 0x00000100);
        info->has_context = has_context;
      }
    }
  }
  else {
    if (NULL != info) LIBXSMM_MEMZERO127(info);
    result = LIBXSMM_X86_GENERIC;
  }
#else
# if !defined(NDEBUG)
  static int error_once = 0;
  if (0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_x86 called on non-x86 platform!\n");
  }
# endif
  if (NULL != info) LIBXSMM_MEMZERO127(info);
#endif
  return result;
}

LIBXSMM_API int libxsmm_cpuid(void)
{
#if defined(LIBXSMM_PLATFORM_X86)
  return libxsmm_cpuid_x86(NULL/*info*/);
#else
  return libxsmm_cpuid_arm(NULL/*info*/);
#endif
}

/**
 * This implementation also accounts for non-x86 platforms,
 * which not only allows to resolve any given ID but to
 * fallback gracefully ("unknown").
 */
LIBXSMM_API const char* libxsmm_cpuid_name(int id)
{
  const char* target_arch = NULL;
  switch (id) {
    case LIBXSMM_X86_AVX512_SPR: {
      target_arch = "spr";
    } break;
    case LIBXSMM_X86_AVX512_CPX: {
      target_arch = "cpx";
    } break;
    case LIBXSMM_X86_AVX512_CLX: {
      target_arch = "clx";
    } break;
    case LIBXSMM_X86_AVX512_CORE: {
      target_arch = "skx";
    } break;
    case LIBXSMM_X86_AVX512_KNM: {
      target_arch = "knm";
    } break;
    case LIBXSMM_X86_AVX512_MIC: {
      target_arch = "knl";
    } break;
    case LIBXSMM_X86_AVX512: {
      /* TODO: rework BE to use target ID instead of set of strings (target_arch = "avx3") */
      target_arch = "hsw";
    } break;
    case LIBXSMM_X86_AVX2: {
      target_arch = "hsw";
    } break;
    case LIBXSMM_X86_AVX: {
      target_arch = "snb";
    } break;
    case LIBXSMM_X86_SSE42: {
      target_arch = "wsm";
    } break;
    case LIBXSMM_X86_SSE3: {
      target_arch = "sse3";
    } break;
    case LIBXSMM_AARCH64_V81: {
      target_arch = "aarch64";
    } break;
    case LIBXSMM_AARCH64_A64FX: {
      target_arch = "a64fx";
    } break;
    case LIBXSMM_TARGET_ARCH_GENERIC: {
      target_arch = "generic";
    } break;
    default: if (LIBXSMM_X86_GENERIC <= id) {
      target_arch = "x86_64";
    }
    else {
      target_arch = "unknown";
    }
  }
  LIBXSMM_ASSERT(NULL != target_arch);
  return target_arch;
}

/**
 * This implementation also accounts for non-x86 platforms,
 * which not only allows to resolve any given ID but to
 * fallback gracefully (scalar).
 */
LIBXSMM_API int libxsmm_cpuid_vlen32(int id)
{
  int result;
#if defined(LIBXSMM_PLATFORM_X86)
  if (LIBXSMM_X86_AVX512 <= id) {
    result = 16;
  }
  else if (LIBXSMM_X86_AVX <= id) {
    result = 8;
  }
  else if (LIBXSMM_X86_SSE42 <= id) {
    result = 4;
  }
  else
#elif defined(LIBXSMM_PLATFORM_AARCH64)
  if (LIBXSMM_AARCH64_V81 == id) {
    result = 4;
  }
  else if (LIBXSMM_AARCH64_A64FX == id) {
    result = 16;
  }
  else
#else
  LIBXSMM_UNUSED(id);
#endif
  { /* scalar */
    result = 1;
  }
  return result;
}
```
third_party/libxsmm/src/libxsmm_diff.h — new file (mode 100644):

```c
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                       *
* SPDX-License-Identifier: BSD-3-Clause                                      *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DIFF_H
#define LIBXSMM_DIFF_H

#include <libxsmm_intrinsics_x86.h>

#if !defined(LIBXSMM_DIFF_AVX512_ENABLED) && 0
# define LIBXSMM_DIFF_AVX512_ENABLED
#endif

#define LIBXSMM_DIFF_4_DECL(A) const uint32_t* /*const*/ A = NULL
#define LIBXSMM_DIFF_4_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_4_LOAD(A, SRC) A = (const uint32_t*)(SRC)
#define LIBXSMM_DIFF_4(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint32_t*)(B)))))

#define LIBXSMM_DIFF_8_DECL(A) const uint64_t* /*const*/ A = NULL
#define LIBXSMM_DIFF_8_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_8_LOAD(A, SRC) A = (const uint64_t*)(SRC)
#define LIBXSMM_DIFF_8(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint64_t*)(B)))))

#define LIBXSMM_DIFF_SSE_DECL(A) __m128i A = LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128()
#define LIBXSMM_DIFF_SSE_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_SSE_LOAD(A, SRC) A = LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(SRC))
#define LIBXSMM_DIFF_SSE(A, B, ...) ((unsigned char)(0xFFFF != _mm_movemask_epi8(_mm_cmpeq_epi8( \
  A, LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(B))))))

#if (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH) /*|| defined(LIBXSMM_INTRINSICS_TARGET)*/
# define LIBXSMM_DIFF_16_DECL LIBXSMM_DIFF_SSE_DECL
# define LIBXSMM_DIFF_16_ASSIGN LIBXSMM_DIFF_SSE_ASSIGN
# define LIBXSMM_DIFF_16_LOAD LIBXSMM_DIFF_SSE_LOAD
# define LIBXSMM_DIFF_16 LIBXSMM_DIFF_SSE
#else
# define LIBXSMM_DIFF_16_DECL(A) const uint64_t* /*const*/ A = NULL
# define LIBXSMM_DIFF_16_ASSIGN(A, B) (A) = (B)
# define LIBXSMM_DIFF_16_LOAD(A, SRC) A = (const uint64_t*)(SRC)
# define LIBXSMM_DIFF_16(A, B, ...) ((unsigned char)(0 != (((A)[0] ^ (*(const uint64_t*)(B))) | \
    ((A)[1] ^ ((const uint64_t*)(B))[1]))))
#endif

#define LIBXSMM_DIFF_AVX2_DECL(A) __m256i A = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256()
#define LIBXSMM_DIFF_AVX2_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_AVX2_LOAD(A, SRC) A = _mm256_loadu_si256((const __m256i*)(SRC))
#define LIBXSMM_DIFF_AVX2(A, B, ...) ((unsigned char)(-1 != _mm256_movemask_epi8(_mm256_cmpeq_epi8( \
  A, _mm256_loadu_si256((const __m256i*)(B))))))

#if (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_DIFF_32_DECL LIBXSMM_DIFF_AVX2_DECL
# define LIBXSMM_DIFF_32_ASSIGN LIBXSMM_DIFF_AVX2_ASSIGN
# define LIBXSMM_DIFF_32_LOAD LIBXSMM_DIFF_AVX2_LOAD
# define LIBXSMM_DIFF_32 LIBXSMM_DIFF_AVX2
#else
# define LIBXSMM_DIFF_32_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_16_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _))
# define LIBXSMM_DIFF_32_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_16_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_32_, B, _))
# define LIBXSMM_DIFF_32_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_16_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(SRC) + 2)
# define LIBXSMM_DIFF_32(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_16(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__)))
#endif

#define LIBXSMM_DIFF_48_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _))
#define LIBXSMM_DIFF_48_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_48_, B, _))
#define LIBXSMM_DIFF_48_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(SRC) + 2)
#define LIBXSMM_DIFF_48(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__)))

#define LIBXSMM_DIFF_64SW_DECL(A) LIBXSMM_DIFF_32_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _))
#define LIBXSMM_DIFF_64SW_ASSIGN(A, B) LIBXSMM_DIFF_32_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_64_, B, _))
#define LIBXSMM_DIFF_64SW_LOAD(A, SRC) LIBXSMM_DIFF_32_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(SRC) + 4)
#define LIBXSMM_DIFF_64SW(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_32(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(B) + 4, __VA_ARGS__)))

#if defined(LIBXSMM_DIFF_AVX512_ENABLED)
# define LIBXSMM_DIFF_AVX512_DECL(A) __m512i A = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32()
# define LIBXSMM_DIFF_AVX512_ASSIGN(A, B) (A) = (B)
# define LIBXSMM_DIFF_AVX512_LOAD(A, SRC) A = _mm512_loadu_si512((const __m512i*)(SRC))
# define LIBXSMM_DIFF_AVX512(A, B, ...) ((unsigned char)(0xFFFF != (unsigned int)/*_cvtmask16_u32*/(_mm512_cmpeq_epi32_mask( \
    A, _mm512_loadu_si512((const __m512i*)(B))))))
#else
# define LIBXSMM_DIFF_AVX512_DECL LIBXSMM_DIFF_64SW_DECL
# define LIBXSMM_DIFF_AVX512_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN
# define LIBXSMM_DIFF_AVX512_LOAD LIBXSMM_DIFF_64SW_LOAD
# define LIBXSMM_DIFF_AVX512 LIBXSMM_DIFF_64SW
#endif

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_AVX512_DECL
# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_AVX512_ASSIGN
# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_AVX512_LOAD
# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_AVX512
#else
# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_64SW_DECL
# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN
# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_64SW_LOAD
# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_64SW
#endif

#define LIBXSMM_DIFF_DECL(N, A) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _DECL)(A)
#define LIBXSMM_DIFF_LOAD(N, A, SRC) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _LOAD)(A, SRC)
#define LIBXSMM_DIFF(N) LIBXSMM_CONCATENATE(LIBXSMM_DIFF_, N)

#define LIBXSMM_DIFF_N(TYPE, RESULT, DIFF, A, BN, ELEMSIZE, STRIDE, HINT, N) { \
  const char* libxsmm_diff_b_ = (const char*)(BN) + (size_t)(HINT) * (STRIDE); \
  for (RESULT = (HINT); (RESULT) < (N); ++(RESULT)) { \
    if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) break; \
    libxsmm_diff_b_ += (STRIDE); \
  } \
  if ((N) == (RESULT)) { /* wrong hint */ \
    TYPE libxsmm_diff_r_ = 0; \
    libxsmm_diff_b_ = (const char*)(BN); /* reset */ \
    for (; libxsmm_diff_r_ < (HINT); ++libxsmm_diff_r_) { \
      if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) { \
        RESULT = libxsmm_diff_r_; \
        break; \
      } \
      libxsmm_diff_b_ += (STRIDE); \
    } \
  } \
}

/** Function type representing the diff-functionality. */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE unsigned int (*libxsmm_diff_function)(
  const void* /*a*/, const void* /*b*/, ... /*size*/);

/** Compare two data blocks of 4 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_4(const void* a, const void* b, ...);
/** Compare two data blocks of 8 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_8(const void* a, const void* b, ...);
/** Compare two data blocks of 16 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_16(const void* a, const void* b, ...);
/** Compare two data blocks of 32 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_32(const void* a, const void* b, ...);
/** Compare two data blocks of 48 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_48(const void* a, const void* b, ...);
/** Compare two data blocks of 64 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_64(const void* a, const void* b, ...);

#endif /*LIBXSMM_DIFF_H*/
```
third_party/libxsmm/src/libxsmm_dnn.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include <libxsmm_dnn.h>
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
LIBXSMM_API_INTERN
void
libxsmm_dnn_init
(
int
target_arch
)
{
LIBXSMM_UNUSED
(
target_arch
);
}
LIBXSMM_API_INTERN
void
libxsmm_dnn_finalize
(
void
)
{
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks(
  int C, int K, int* C_block, int* K_block, int* fm_lp_block,
  libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  int ifmblock = 0;
  int ofmblock = 0;
  int lp_block = 0;
  int tmp_max_c_block = 64;
  int tmp_max_k_block = 64;
  int tmp_block = 0;

  /* init libxsmm */
  LIBXSMM_INIT

  /* C */
  if ( ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (datatype_in == LIBXSMM_DNN_DATATYPE_BF16))
       || (libxsmm_target_archid < LIBXSMM_X86_AVX512) ) {
    tmp_max_c_block = 32;
  } else if ( libxsmm_target_archid == LIBXSMM_AARCH64_V81 ) {
    tmp_max_c_block = 16;
  }
  if ( C < tmp_max_c_block ) {
    ifmblock = C;
  } else {
    for ( tmp_block = 1; tmp_block <= tmp_max_c_block; tmp_block *= 2 ) {
      if ( C % tmp_block == 0 ) ifmblock = tmp_block;
    }
  }

  /* K */
  if ( ((libxsmm_target_archid >= LIBXSMM_X86_AVX512_SPR) && (datatype_in == LIBXSMM_DNN_DATATYPE_BF16))
       || (libxsmm_target_archid < LIBXSMM_X86_AVX512) ) {
    tmp_max_k_block = 32;
  } else if ( libxsmm_target_archid == LIBXSMM_AARCH64_V81 ) {
    tmp_max_k_block = 16;
  }
  if ( K < tmp_max_k_block ) {
    ofmblock = K;
  } else {
    for ( tmp_block = 1; tmp_block <= tmp_max_k_block; tmp_block *= 2 ) {
      if ( K % tmp_block == 0 ) ofmblock = tmp_block;
    }
  }

  /* when do we need VNNI format? */
  if ( (datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
    lp_block = 1;
  } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
    lp_block = 2;
  } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_I16) && ((datatype_out == LIBXSMM_DNN_DATATYPE_I32) || (datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
    lp_block = 2;
  } else if ( datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) {
    lp_block = 4;
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    return status;
  }

  *C_block = ifmblock;
  *K_block = ofmblock;
  *fm_lp_block = lp_block;

  return status;
}
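The blocking rule above amounts to: keep the whole channel count when it is below the cap, otherwise take the largest power of two up to the cap that divides it. A standalone restatement of that rule for checking (the values are invented for illustration):

#include <stdio.h>

/* Same selection rule as the C/K loops above. */
static int pick_block(int c, int max_block)
{
  int block = c, tmp;
  if (c >= max_block) {
    for (tmp = 1; tmp <= max_block; tmp *= 2) {
      if (c % tmp == 0) block = tmp; /* remember the largest dividing power of two */
    }
  }
  return block;
}

int main(void)
{
  printf("%d\n", pick_block(256, 64)); /* 64: the full block fits */
  printf("%d\n", pick_block(96, 64));  /* 32: largest power of two dividing 96 */
  printf("%d\n", pick_block(24, 64));  /* 24: below the cap, kept whole */
  return 0;
}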
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code)
{
  switch (code) {
    case LIBXSMM_DNN_SUCCESS: return "LIBXSMM DNN Success!";
    case LIBXSMM_DNN_WARN_FALLBACK: return "LIBXSMM DNN Warning: Falling back to naive code as target is currently not supported by LIBXSMM!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING: return "LIBXSMM DNN Warning: RNN cell suboptimal minibatch blocking!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING: return "LIBXSMM DNN Warning: RNN cell suboptimal input feature blocking!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING: return "LIBXSMM DNN Warning: RNN cell suboptimal output feature blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING: return "LIBXSMM DNN Warning: FC layer suboptimal minibatch blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING: return "LIBXSMM DNN Warning: FC layer suboptimal input feature blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING: return "LIBXSMM DNN Warning: FC layer suboptimal output feature blocking!";
    case LIBXSMM_DNN_ERR_GENERAL: return "LIBXSMM DNN Error: General error occurred!";
    case LIBXSMM_DNN_ERR_CREATE_HANDLE: return "LIBXSMM DNN Error: Handle creation failed!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE: return "LIBXSMM DNN Error: Requested datatype is not available!";
    case LIBXSMM_DNN_ERR_INVALID_BLOCKING: return "LIBXSMM DNN Error: Requested Input/Output buffer size cannot be blocked!";
    case LIBXSMM_DNN_ERR_INVALID_HANDLE: return "LIBXSMM DNN Error: An invalid handle was provided!";
    case LIBXSMM_DNN_ERR_DATA_NOT_BOUND: return "LIBXSMM DNN Error: Not all required sources and destinations have been bound to convolution!";
    case LIBXSMM_DNN_ERR_CREATE_TENSOR: return "LIBXSMM DNN Error: Tensor creation failed!";
    case LIBXSMM_DNN_ERR_INVALID_TENSOR: return "LIBXSMM DNN Error: Invalid tensor was specified!";
    case LIBXSMM_DNN_ERR_MISMATCH_TENSOR: return "LIBXSMM DNN Error: Tensor doesn't match the handle it should be bound to!";
    case LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR: return "LIBXSMM DNN Error: Invalid handle or tensor!";
    case LIBXSMM_DNN_ERR_INVALID_KIND: return "LIBXSMM DNN Error: Invalid convolution kind!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW: return "LIBXSMM DNN Error: NCHW format is currently not natively supported by LIBXSMM!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT: return "LIBXSMM DNN Error: Unsupported destination format when copying data!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT: return "LIBXSMM DNN Error: Unsupported source format when copying data!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE: return "LIBXSMM DNN Error: Unsupported format when requesting a convolution!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS: return "LIBXSMM DNN Error: KCRS format is currently not natively supported by LIBXSMM!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL: return "LIBXSMM DNN Error: Invalid format was specified!";
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT: return "LIBXSMM DNN Error: Layout creation failed!";
    case LIBXSMM_DNN_ERR_INVALID_LAYOUT: return "LIBXSMM DNN Error: Invalid layout was specified!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH: return "LIBXSMM DNN Error: Unsupported architecture!";
    case LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED: return "LIBXSMM DNN Error: scratch binding failed as scratch was not allocated!";
    case LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE: return "LIBXSMM DNN Error: an unknown tensor type was provided!";
    case LIBXSMM_DNN_ERR_INVALID_ALGO: return "LIBXSMM DNN Error: Invalid algorithm was specified!";
    case LIBXSMM_DNN_ERR_INVALID_PADDING: return "LIBXSMM DNN Error: Invalid padding was specified!";
    case LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL: return "LIBXSMM DNN Error: time steps should be >= 2 for RNN/LSTM!";
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS: return "LIBXSMM DNN Error: failed to create internal layout arrays!";
    case LIBXSMM_DNN_ERR_NOT_IMPLEMENTED: return "LIBXSMM DNN Error: the requested functionality is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER: return "LIBXSMM DNN Error: the requested order of fusion in batch norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION: return "LIBXSMM DNN Error: the requested fusion in batch norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN: return "LIBXSMM DNN Error: Unsupported format when requesting a fused batch norm!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING: return "LIBXSMM DNN Error: Unsupported pooling operation was requested!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FC: return "LIBXSMM DNN Error: Unsupported format when requesting a fullyconnected layer!";
    case LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN: return "LIBXSMM DNN Error: max sequence length is shorter than sequence length we attempt to set!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER: return "LIBXSMM DNN Error: the requested order of fusion in group norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION: return "LIBXSMM DNN Error: the requested fusion in group norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION: return "LIBXSMM DNN Error: the requested fusion in fullyconnected is right now not implemented!";
    default: return "LIBXSMM DNN Error: Unknown error or warning occurred!";
  }
}
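Call sites typically branch on the numeric status and use libxsmm_dnn_get_error() only for reporting. A minimal sketch of that pattern (the status value is hard-coded here purely for illustration):

#include <stdio.h>
#include <libxsmm_dnn.h>

int main(void)
{
  /* pretend some libxsmm_dnn call returned this code */
  const libxsmm_dnn_err_t status = LIBXSMM_DNN_WARN_FALLBACK;
  if (LIBXSMM_DNN_SUCCESS != status) {
    fprintf(stderr, "%s\n", libxsmm_dnn_get_error(status)); /* warnings and errors share the table above */
  }
  return 0;
}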
LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)
{
  switch (datatype) {
    case LIBXSMM_DNN_DATATYPE_F32:  return 4;
    case LIBXSMM_DNN_DATATYPE_I32:  return 4;
    case LIBXSMM_DNN_DATATYPE_BF16: return 2;
    case LIBXSMM_DNN_DATATYPE_I16:  return 2;
    case LIBXSMM_DNN_DATATYPE_I8:   return 1;
    /* no error expected as enumeration really arrives at an enum; compiler-checked */
    default: return 1;
  }
}

LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)
{
  size_t l_cl_width_bytes;

  /* init libxsmm */
  LIBXSMM_INIT

  if ( libxsmm_target_archid == LIBXSMM_X86_GENERIC ||
       libxsmm_target_archid == LIBXSMM_X86_SSE3 ||
       libxsmm_target_archid == LIBXSMM_X86_SSE42 ) {
    l_cl_width_bytes = 16;
  } else if ( libxsmm_target_archid == LIBXSMM_X86_AVX2 ||
              libxsmm_target_archid == LIBXSMM_X86_AVX ) {
    l_cl_width_bytes = 32;
  } else {
    l_cl_width_bytes = 64;
  }

  return l_cl_width_bytes / libxsmm_dnn_typesize(datatype);
}
LIBXSMM_API_INLINE float libxsmm_internal_get_max(float* in_buffer, int length)
{
  float absmax_value = LIBXSMM_ABS(in_buffer[0]);
  int i = 0;
#ifdef _OPENMP
  LIBXSMM_OMP_VAR(i);
# pragma omp parallel private(i)
  {
    float my_absmax_value = absmax_value;
#   pragma omp for
    for (i = 0; i < length; ++i) {
      if (LIBXSMM_ABS(in_buffer[i]) > my_absmax_value) {
        my_absmax_value = LIBXSMM_ABS(in_buffer[i]);
      }
    }
#   pragma omp critical
    {
      if (my_absmax_value > absmax_value) {
        absmax_value = my_absmax_value;
      }
    }
  }
#else
  for (i = 1; i < length; ++i) {
    if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) {
      absmax_value = LIBXSMM_ABS(in_buffer[i]);
    }
  }
#endif
  return absmax_value;
}

LIBXSMM_API_INLINE unsigned char libxsmm_internal_get_max_exp(float* in_buffer, int length)
{
  libxsmm_intfloat val_exp;
  unsigned char max_exp = 0;

  /* bit-wise conversion to int */
  val_exp.f = libxsmm_internal_get_max(in_buffer, length);
  /* shift by mantissa to the right and convert to char */
  max_exp = (unsigned char)((val_exp.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32);

  return max_exp;
}
LIBXSMM_API_INLINE short libxsmm_internal_quantize_scalar_no_scf(float input, unsigned char max_exp, unsigned char add_shift, int round_mode)
{
  libxsmm_intfloat value;
  unsigned int qvalue = 0;
  unsigned int mant = 0;
  unsigned int sign = 0;
  unsigned char rhs = 0;
  unsigned char exp_off = 0;

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case of zero we don't need to do anything */
  if (LIBXSMM_FEQ(input, 0)) {
    qvalue = 0;
  } else {
    /* let's get a float copy to work on */
    /* vinp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( in_buffer[i] ); */
    value.f = input;
    /* let's compute the offset of the current exp at pos i from max offset, we need to mask the sign bit though */
    /*__m512i vexp = _mm512_cvtps_epi32(_mm512_getexp_ps (vinp));
      __m512i vexp_off = _mm512_sub_epi32(maxexpf, vexp);*/
    exp_off = (unsigned char)(max_exp - ((value.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32));
    /* cut out mantissa and set leading bit */
    /*__m512i mmask = _mm512_set1_epi32(LIBXSMM_DNN_MASK_MANT_F32);
      __m512i vmant = _mm512_or_epi32(_mm512_set1_epi32(0x1 << LIBXSMM_DNN_MANT_SZ_F32), _mm512_and_epi32( _mm512_castps_si512( vinp ), mmask));*/
    mant = ((0x1 << LIBXSMM_DNN_MANT_SZ_F32) | (value.ui & LIBXSMM_DNN_MASK_MANT_F32));
    /* extract sign */
    /* __mmask16 smask = _mm512_cmplt_ps_mask (inp, _mm512_set1_ps(0)); */
    sign = ((value.ui & LIBXSMM_DNN_MASK_SIGN_F32) >> (LIBXSMM_DNN_SZ_F32 - 1));
    /* calculate rhs, be aware of the now explicit leading bit, @TODO add DFP8/4 */
    rhs = (unsigned char)((LIBXSMM_DNN_MANT_SZ_F32 + 1) - LIBXSMM_DNN_MANT_DFP16 + exp_off + add_shift);
    /* some safety, to generate 0 when we fall off the quant region, @TODO issue a LIBXSMM WARNING that we shifted out the entire mantissa */
    if (rhs > (LIBXSMM_DNN_MANT_SZ_F32 + 1)) {
      rhs = (LIBXSMM_DNN_MANT_SZ_F32 + 1);
    }
    /* finally shift the value into the region we need, this is now a 15-add_rhs bit number for the max value in in_buffer */
    qvalue = (mant >> rhs);
    /* handle sign, 2 complement */
    if ( (sign > 0) && (qvalue > 0) ) {
      qvalue = (~qvalue + 1);
    }

    if (round_mode == LIBXSMM_DNN_QUANT_BIAS_ROUND) {
      /* biased rounding towards next bigger number */
      /* first let's determine in the original number if we need a bias rounding, @TODO need fix for F64 */
      int bias_needed = (mant & (0x3 << (rhs - 2)));
      /* apply bias */
      if (bias_needed > 0) {
        qvalue++;
      }
    } else if (round_mode == LIBXSMM_DNN_QUANT_NEAREST_ROUND) {
      int nearest_needed = (mant & (0x1 << (rhs - 1)));
      /* apply rounding */
      if ((nearest_needed > 0) && (rhs > 1)) {
        qvalue++;
      }
    } else if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
      /* stochastic rounding, as implemented in the IBM paper from 2015, @TODO, fix F64 and DFP8 */
      const float eps = LIBXSMM_DNN_RES_DFP16;
      /* coverity[dont_call] */
      const float r = (float)rand();
      libxsmm_intfloat fvalue;
      float p, q;
      /* masking all bits which will be shifted out */
      fvalue.ui = value.ui & ((LIBXSMM_DNN_MASK_FULL_F32) << rhs);
      /* drawing a random number */
      p = r / ((float)RAND_MAX);
      q = (input - fvalue.f) / eps;
      /* apply rounding if needed */
      if ((p + q) > 0.5f) {
        ++qvalue;
      }
    } else {
      /* do nothing about rounding, just chop */
    }
  }

  return (short)qvalue;
}
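The scalar path is easiest to follow on a concrete value. The trace below quantizes input = 1.5f against a buffer maximum of 2.0f (so max_exp = 128), assuming the usual F32 constants (23 mantissa bits) and the DFP16 mantissa width of 15; the arithmetic is hand-derived for illustration, not library output:

#include <stdio.h>

int main(void)
{
  const unsigned int ui = 0x3FC00000u;   /* bit pattern of 1.5f */
  const unsigned char max_exp = 128;     /* exponent field of 2.0f */
  const unsigned char exp_off = (unsigned char)(max_exp - ((ui & 0x7FFFFFFFu) >> 23)); /* = 1 */
  const unsigned int mant = (1u << 23) | (ui & 0x007FFFFFu); /* implicit leading bit made explicit */
  const unsigned char rhs = (unsigned char)((23 + 1) - 15 + exp_off); /* add_shift = 0 -> rhs = 10 */
  const short q = (short)(mant >> rhs);                      /* 0xC00000 >> 10 = 12288 */
  const float back = (float)q / (float)(1 << (14 - (max_exp - 127))); /* scf = 13 */
  printf("q = %d, dequantized = %f\n", q, back); /* q = 12288, dequantized = 1.500000 */
  return 0;
}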
/* @TODO make this routine aware of any int type */
LIBXSMM_API void libxsmm_dnn_quantize(float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode)
{
  int i = 0;

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if (round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND) {
    const float max_value = libxsmm_internal_get_max(in_buffer, length);
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if (length % 16 == 0) {
      __m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
#     pragma omp parallel for private(i)
#endif
      for (i = 0; i < length; i += 16) {
        _mm256_stream_si256((__m256i*)&(out_buffer[i]),
          LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(&(in_buffer[i]), vscfq));
      }
    } else {
#endif
#ifdef _OPENMP
#     pragma omp parallel for private(i)
#endif
      for (i = 0; i < length; ++i) {
        out_buffer[i] = (short)LIBXSMM_ROUNDF(in_buffer[i] * scfq);
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG)
    /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp(in_buffer, length);

    /* if we go for stochastic rounding, let's initialize random seed */
    if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i)
#endif
    for (i = 0; i < length; ++i) {
      out_buffer[i] = libxsmm_internal_quantize_scalar_no_scf(in_buffer[i], max_exp, add_shift, round_mode);
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
LIBXSMM_API void libxsmm_dnn_quantize_act(float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode)
{
  LIBXSMM_VLA_DECL(5, const float, in, in_buffer, C/cblk_f32, H, W, cblk_f32);
  LIBXSMM_VLA_DECL(6, short, out, out_buffer, C/(cblk_i16*lp_blk), H, W, cblk_i16, lp_blk);
  const unsigned int cblk = C/(cblk_i16*lp_blk);
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6;

  /* init libxsmm */
  LIBXSMM_INIT

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % cblk_i16) == 0);

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if (round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND) {
    const float max_value = libxsmm_internal_get_max(in_buffer, N*C*H*W);
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( (cblk_f32 == 16) && (cblk_i16*lp_blk == 16) ) {
      __m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1);
#     pragma omp parallel for private(i1)
#endif
      for (i1 = 0; i1 < (int)(N*C*H*W); i1 += 16) {
        _mm256_stream_si256((__m256i*)&(out_buffer[i1]),
          LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(&(in_buffer[i1]), vscfq));
      }
    } else {
#endif
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3);
      LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)N; ++i1) {
        for (i2 = 0; i2 < (int)cblk; ++i2) {
          for (i3 = 0; i3 < (int)H; ++i3) {
            for (i4 = 0; i4 < (int)W; ++i4) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5) {
                for (i6 = 0; i6 < (int)lp_blk; ++i6) {
                  const int fi1 = i1;
                  const int fi2 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i6) / cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i6) % cblk_f32;
                  LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                    LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C/cblk_f32, H, W, cblk_f32) * scfq);
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG)
    /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant act\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp(in_buffer, N*C*H*W);

    /* if we go for stochastic rounding, let's initialize random seed */
    if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)N; ++i1) {
      for (i2 = 0; i2 < (int)cblk; ++i2) {
        for (i3 = 0; i3 < (int)H; ++i3) {
          for (i4 = 0; i4 < (int)W; ++i4) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5) {
              for (i6 = 0; i6 < (int)lp_blk; ++i6) {
                const int fi1 = i1;
                const int fi2 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i6) / cblk_f32;
                const int fi3 = i3;
                const int fi4 = i4;
                const int fi5 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i6) % cblk_f32;
                LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) =
                  libxsmm_internal_quantize_scalar_no_scf(
                    LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C/cblk_f32, H, W, cblk_f32),
                    max_exp, add_shift, round_mode);
              }
            }
          }
        }
      }
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
LIBXSMM_API void libxsmm_dnn_quantize_fil(float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode)
{
  LIBXSMM_VLA_DECL(6, const float, in, in_buffer, C/cblk_f32, R, S, cblk_f32, kblk_f32);
  LIBXSMM_VLA_DECL(7, short, out, out_buffer, C/(cblk_i16*lp_blk), R, S, cblk_i16, kblk_i16, lp_blk);
  unsigned int cblk = C/(cblk_i16*lp_blk);
  unsigned int kblk = K/kblk_i16;
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6, i7;

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % (cblk_i16*lp_blk)) == 0);
  assert((K % kblk_f32) == 0);
  assert((K % kblk_i16) == 0);
  assert((lp_blk % 2) == 0);

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if (round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND) {
    const float max_value = libxsmm_internal_get_max(in_buffer, K*C*R*S);
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( (kblk_f32 == 16) && (cblk_f32 == 16) && (kblk_i16 == 16) && (cblk_i16*lp_blk == 16) ) {
      const __m512 vscfq = _mm512_set1_ps(scfq);
      const __m512i permute_compact_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
#ifdef _OPENMP
#     pragma omp parallel for private(i1, i2, i3, i4, i5) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1) {
        for (i2 = 0; i2 < (int)cblk; ++i2) {
          for (i3 = 0; i3 < (int)R; ++i3) {
            for (i4 = 0; i4 < (int)S; ++i4) {
              for (i5 = 0; i5 < 16; i5 += 2) {
                __m256i even_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 0, 0, C/cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
                __m256i odd_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 1, 0, C/cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
                __m256i compressed_lo = _mm256_unpacklo_epi16(even_ch, odd_ch);
                __m256i compressed_hi = _mm256_unpackhi_epi16(even_ch, odd_ch);
                __m512i compact = _mm512_inserti64x4(_mm512_setzero_si512(), compressed_lo, 0);
                compact = _mm512_inserti64x4(compact, compressed_hi, 1);
                compact = _mm512_permutexvar_epi32(permute_compact_idx, compact);
                LIBXSMM_INTRINSICS_MM512_STREAM_SI512(
                  (void*)&LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5/2, 0, 0, cblk, R, S, cblk_i16, kblk_i16, lp_blk),
                  compact);
              }
            }
          }
        }
      }
    } else {
#endif
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4);
      LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); LIBXSMM_OMP_VAR(i7);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1) {
        for (i2 = 0; i2 < (int)cblk; ++i2) {
          for (i3 = 0; i3 < (int)R; ++i3) {
            for (i4 = 0; i4 < (int)S; ++i4) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5) {
                for (i6 = 0; i6 < (int)kblk_i16; ++i6) {
                  for (i7 = 0; i7 < (int)lp_blk; ++i7) {
                    const int fi1 = ((i1*kblk_i16) + i6) / kblk_f32;
                    const int fi2 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i7) / cblk_f32;
                    const int fi3 = i3;
                    const int fi4 = i4;
                    const int fi5 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i7) % cblk_f32;
                    const int fi6 = ((i1*kblk_i16) + i6) % kblk_f32;
                    LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                      LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C/cblk_f32, R, S, cblk_f32, kblk_f32) * scfq);
                  }
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG)
    /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp(in_buffer, K*C*R*S);

    /* if we go for stochastic rounding, let's initialize random seed */
    if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)kblk; ++i1) {
      for (i2 = 0; i2 < (int)cblk; ++i2) {
        for (i3 = 0; i3 < (int)R; ++i3) {
          for (i4 = 0; i4 < (int)S; ++i4) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5) {
              for (i6 = 0; i6 < (int)kblk_i16; ++i6) {
                for (i7 = 0; i7 < (int)lp_blk; ++i7) {
                  const int fi1 = ((i1*kblk_i16) + i6) / kblk_f32;
                  const int fi2 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i7) / cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk) + (i5*lp_blk) + i7) % cblk_f32;
                  const int fi6 = ((i1*kblk_i16) + i6) % kblk_f32;
                  LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) =
                    libxsmm_internal_quantize_scalar_no_scf(
                      LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C/cblk_f32, R, S, cblk_f32, kblk_f32),
                      max_exp, add_shift, round_mode);
                }
              }
            }
          }
        }
      }
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
LIBXSMM_API void libxsmm_dnn_dequantize(short* in_buffer, float* out_buffer, int length, unsigned char scf)
{
  const float val_exp = libxsmm_sexp2_i8i(-scf);
  int i = 0;

#ifdef _OPENMP
# pragma omp parallel for private(i)
#endif
  for (i = 0; i < length; ++i) {
    out_buffer[i] = ((float)in_buffer[i]) * val_exp;
  }
}
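Quantize and dequantize are meant to round-trip up to the DFP16 resolution implied by the returned scale factor. A short sketch of the pairing, using the two routines exactly as declared above (the buffer contents are arbitrary):

#include <stdio.h>
#include <libxsmm_dnn.h>

int main(void)
{
  float in[4] = { 0.25f, -1.0f, 3.5f, 0.0f };
  float back[4];
  short q[4];
  unsigned char scf = 0;
  int i;
  /* nearest rounding, no extra shift */
  libxsmm_dnn_quantize(in, q, 4, 0, &scf, LIBXSMM_DNN_QUANT_NEAREST_ROUND);
  libxsmm_dnn_dequantize(q, back, 4, scf);
  for (i = 0; i < 4; ++i) {
    printf("%f -> %d -> %f\n", (double)in[i], (int)q[i], (double)back[i]);
  }
  return 0;
}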
LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length)
{
  unsigned int i = 0;

  /* truncate buffer to bf16 */
  for (i = 0; i < length; ++i) {
    libxsmm_bfloat16_hp t;
    t.f = in[i];
    out[i] = t.i[1];
  }
}

LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len)
{
  unsigned int i = 0;

  /* round buffer to bf16, ties away from zero */
  for (i = 0; i < len; ++i) {
    unsigned int int_round = 0;
    unsigned int do_round = 1;

    int_round = *((unsigned int*)&(in[i]));

    /* we don't round NaN and inf */
    if ( (int_round & 0x7f800000) == 0x7f800000 ) {
      do_round = 0;
    }

    /* perform round nearest tie away from zero */
    if ( do_round != 0 ) {
      int_round = int_round + 0x00008000;
    }

    /* create the bf16 value by shifting out the lower 16 bits */
    int_round = int_round >> 16;

    out[i] = (libxsmm_bfloat16)int_round;
  }
}
LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len)
{
  unsigned int i = 0;

  /* round buffer to bf16, ties to even */
  for (i = 0; i < len; ++i) {
    unsigned int int_round = 0;
    unsigned int do_round = 1;

    int_round = *((unsigned int*)&(in[i]));

    /* we don't round NaN and inf */
    if ( (int_round & 0x7f800000) == 0x7f800000 ) {
      do_round = 0;
    }

    /* perform round nearest tie even */
    if ( do_round != 0 ) {
      unsigned int fixup = (int_round >> 16) & 1;
      int_round = int_round + 0x00007fff + fixup;
    }

    /* create the bf16 value by shifting out the lower 16 bits */
    int_round = int_round >> 16;

    out[i] = (unsigned short)int_round;
  }
}
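The fixup term above is what turns round-half-up into round-to-nearest-even: on an exact tie the result is steered to the bf16 value whose low mantissa bit is zero. A self-contained bit-level illustration on two hand-picked halfway patterns:

#include <stdio.h>
#include <string.h>

/* f32 -> bf16 with the same RNE fixup as above (memcpy avoids strict-aliasing issues). */
static unsigned short rne_bf16(float f)
{
  unsigned int u;
  memcpy(&u, &f, sizeof(u));
  if ((u & 0x7f800000u) != 0x7f800000u) { /* leave NaN and inf alone */
    const unsigned int fixup = (u >> 16) & 1u;
    u = u + 0x00007fffu + fixup;
  }
  return (unsigned short)(u >> 16);
}

int main(void)
{
  float a, b;
  const unsigned int ua = 0x3F808000u, ub = 0x3F818000u; /* both exactly halfway */
  memcpy(&a, &ua, sizeof(a));
  memcpy(&b, &ub, sizeof(b));
  printf("0x%04x 0x%04x\n", rne_bf16(a), rne_bf16(b)); /* 0x3f80 0x3f82: ties land on even */
  return 0;
}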
LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length)
{
  unsigned int i = 0;

  /* up-convert is super simple */
  for (i = 0; i < length; ++i) {
    libxsmm_bfloat16_hp t;
    t.i[1] = in[i];
    t.i[0] = 0;
    out[i] = t.f;
  }
}
third_party/libxsmm/src/libxsmm_dnn_convolution.c 0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke, Evangelos Georganas, Rajkishore Barik (Intel Corp.)
******************************************************************************/
#include <libxsmm_sync.h>
#include "libxsmm_main.h"
#include "libxsmm_dnn_convolution_forward.h"
#include "libxsmm_dnn_convolution_backward.h"
#include "libxsmm_dnn_convolution_weight_update.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#define MIXED 0
#define KHWC 1
#define HWKC 2
#define CHWK 3
#define HWCK 4
/**********************************************************/
/* Helper functions for convolutions' general param setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ifmblock(libxsmm_dnn_layer* handle)
{
  int result = 1;
  int ofm, lp;

  libxsmm_dnn_get_feature_map_blocks(handle->desc.C, handle->desc.K, &result, &ofm, &lp, handle->desc.datatype_in, handle->desc.datatype_out);

  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ofmblock(libxsmm_dnn_layer* handle)
{
  int result = 1;
  int ifm, lp;

  libxsmm_dnn_get_feature_map_blocks(handle->desc.C, handle->desc.K, &ifm, &result, &lp, handle->desc.datatype_in, handle->desc.datatype_out);

  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fm_lp_block(libxsmm_dnn_layer* handle)
{
  int result = 1;
  int ifm, ofm;

  libxsmm_dnn_get_feature_map_blocks(handle->desc.C, handle->desc.K, &ifm, &ofm, &result, handle->desc.datatype_in, handle->desc.datatype_out);

  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_fwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* FIXME: For now fallback only if MB is not divisible by number of threads */
  if (handle->desc.N % handle->desc.threads != 0) {
    result = 1;
  }
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm(libxsmm_dnn_layer* handle)
{
  int result = handle->desc.C / handle->ifmblock;
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm(libxsmm_dnn_layer* handle)
{
  int result = handle->desc.K / handle->ofmblock;
  return result;
}
/**********************************************************/
/* Helper functions for FWD convolutions' parameter setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofw_rb(libxsmm_dnn_layer* handle)
{
  int result = 0;
  result = handle->ofw;
  if (handle->ofw == 56) {
    result = 28;
  }
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
    if (handle->ofw % 2 == 0) {
      result = handle->ofw / 2;
    }
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_fwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* Pack only for small images and when having large K to amortize, and we can only pack for 1x1 convolutions */
  if ((handle->ofw <= 14) && (handle->desc.K > 512) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u == 2) && (handle->desc.v == 2)) {
    result = 1;
  }
  /* For SPR we allow packing more aggressively to generate more efficient BRGEMMs */
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    if ((handle->ofw <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.u == 2) && (handle->desc.v == 2)) {
      result = 1;
    }
  }
  /* Make sure we don't pack when minibatch is not divisible by number of threads since H is used potentially for parallelism */
  if (handle->desc.N != handle->desc.threads) {
    result = 0;
  }
  /* we don't pack for int8 */
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
    result = 0;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofh_rb(libxsmm_dnn_layer* handle)
{
  int result = 1;
  /* Multiple rows for "small" images and 1x1 convolutions */
  if ((handle->ofh <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1)) {
    result = handle->ofh;
  }
  /* In this case we will be using fallback generic loops, thus ofh_rb should be 1 */
  if ((handle->desc.N % handle->desc.threads != 0) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) {
    result = 1;
  }
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    if (handle->ofw == 7 && handle->ofh == 7 && handle->desc.R == 3 && handle->desc.S == 3) {
      result = 7;
    }
    if (handle->ofw == 14 && handle->ofh == 14 /*&& handle->desc.R == 3 && handle->desc.S == 3*/) {
      result = 2;
    }
  }
  /* Make sure we don't use multiple rows when we don't pack input and convolutions are strided */
  if ((handle->pack_input == 0) && ((handle->desc.u != 1) || (handle->desc.v != 1))) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_pixels_gemm(libxsmm_dnn_layer* handle)
{
  int result = handle->fwd_ofw_rb * handle->fwd_ofh_rb;
  /* In the case below we calculate redundant pixels in order to use AMX efficiently */
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    if (handle->desc.R != 1 || handle->desc.S != 1) {
      if (handle->ofw < 24) {
        result = (handle->fwd_ofw_rb + 2 * handle->desc.pad_w) * (handle->fwd_ofh_rb - 2) + 2 * (handle->fwd_ofw_rb + handle->desc.pad_w);
      }
    }
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_block_H(libxsmm_dnn_layer* handle)
{
  int result = 14;
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    /* Spatial dimension block tuning for SPR */
    if ((handle->ofh == 7 && handle->desc.u == 2) || (handle->ofh == 14 && handle->desc.R != 3) || handle->ofh == 27 ||
        (handle->ofh == 28 && handle->desc.R == 1) || handle->ofh == 48 || handle->ofh == 54 || handle->ofh == 56 || handle->ofh == 112) {
      result = 4;
    }
  } else {
    /* Block H only for large images */
    if (handle->ofh >= 28) {
      result = 4;
    }
    if (handle->ofh == 28 && handle->desc.R == 3) {
      result = 14;
    }
  }
  /* Make sure it is divisible by the ofh_rb factor in the kernel */
  while (result % handle->fwd_ofh_rb != 0) {
    result--;
  }
  return result;
}
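The trailing while-loop is a recurring idiom in these setup routines: start from a tuned candidate, then walk it down until the kernel's inner blocking divides it. A minimal standalone version of that adjustment (numbers invented; the setup code relies on tunings where the factor actually fits):

#include <stdio.h>

/* Shrink a candidate block until it is a multiple of the kernel factor. */
static int trim_to_multiple(int block, int factor)
{
  while (block % factor != 0) --block; /* reaches 0 if factor never fits */
  return block;
}

int main(void)
{
  printf("%d\n", trim_to_multiple(14, 4)); /* 12: 14 and 13 are rejected */
  printf("%d\n", trim_to_multiple(14, 7)); /* 14: already divisible */
  return 0;
}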
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm_blocking(libxsmm_dnn_layer* handle)
{
  int result = 1;
  /* For 1x1 convolutions bring in kernel all IFMs unless filters are huge */
  if ((handle->desc.R == 1) && (handle->desc.S == 1)) {
    result = handle->blocksifm;
    if ((handle->desc.C >= 2048) && (handle->desc.K >= 512)) {
      result = 1;
    }
    if ((handle->target_archid < LIBXSMM_X86_AVX512) && (handle->desc.C >= 512)) {
      result = 2;
    }
    if ((handle->target_archid < LIBXSMM_X86_AVX512) && (handle->desc.C >= 1024)) {
      result = 4;
    }
  } else {
    result = 1;
    /* If small image can bring in more IFMs even if NOT 1x1 convolution */
    if (handle->ofw <= 7) {
      result = 2;
    }
  }
  if (handle->blocksifm % result != 0) {
    result = 1;
  }
  /* In case of SPR bring always in all accumulation */
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    result = handle->blocksifm;
  }
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
    result = handle->blocksifm;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_fwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* Switch to loop order 1 only if 1x1 convolution with "large" input image and "small" K */
  if ((handle->desc.H >= 28) && (handle->desc.R == 1) && (handle->desc.S == 1) && (handle->desc.C >= 512) && (handle->desc.K <= 512)) {
    result = 1;
  }
  if (handle->ofw == 56 && handle->desc.R == 1 && handle->desc.C == 256 && handle->desc.K == 64) {
    result = 1;
  }
  if (handle->ofw == 28 && handle->desc.R == 1) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_IFM(libxsmm_dnn_layer* handle)
{
  int result = 8;
  if (handle->ofw == 7 && handle->desc.C == 2048 && handle->desc.K == 512) {
    result = 4;
  }
  /* Make sure it is divisible by ifms in the kernel */
  while (result % handle->blocksifm_blocking != 0) {
    result++;
  }
  result = LIBXSMM_MIN(handle->blocksifm, result);
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_OFM(libxsmm_dnn_layer* handle)
{
  int result = 8;
  if (handle->ofw == 14 && handle->desc.K == 1024) {
    result = 16;
  }
  if (handle->ofw == 7) {
    result = 16;
  }
  result = LIBXSMM_MIN(handle->blocksofm, result);
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ofm_parallelization(libxsmm_dnn_layer* handle)
{
  int result = 0;
#if 0
  /* Use "hybrid" minibatch/ofm parallelization if we have huge filters */
  if ((handle->desc.R >= 3) && (handle->desc.S >= 3) && (handle->desc.C >= 512) && (handle->desc.K >= 512)) {
    result = 1;
  }
#endif
  if ((handle->ofw <= 7) && (handle->desc.C == 1024) && (handle->desc.K == 512)) {
    result = 1;
  }
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    if (handle->ofw == 7) {
      result = 1;
    }
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* Avoid rim FMA if the convolution is 3x3 (non-strided) and the image is "small" */
  if ((handle->desc.R == 3) && (handle->desc.S == 3) &&
      (handle->desc.u == 1) && (handle->desc.v == 1) &&
      (handle->desc.pad_h_in == 1) && (handle->desc.pad_w_in == 1) &&
      (handle->desc.H == handle->desc.W)) {
    if (handle->ofw <= 28) {
      result = 1;
    }
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
      result = 0;
    }
  }
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    result = 0;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_shuffle_filter_accesses(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* Shuffle filter accesses only if "pure minibatch" parallelization and large filters are involved */
  if ((handle->use_ofm_parallelization == 0) && (handle->desc.C > 512) && (handle->desc.K > 512)) {
    result = 1;
  }
  if (handle->ofw == 7 && handle->desc.R == 3 && handle->desc.C == 512) {
    result = 1;
  }
  if (handle->ofw == 7 && handle->desc.R == 1 && handle->desc.C == 512 && handle->desc.K == 2048) {
    result = 1;
  }
  if (handle->ofw == 7 && handle->desc.R == 1 && handle->desc.C == 2048 && handle->desc.K == 512) {
    result = 1;
  }
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    result = 0;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load(libxsmm_dnn_layer* handle)
{
  int result = 0;
  if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) {
    if ((handle->desc.R == 1) && (handle->desc.S == 1)) {
      if (handle->blocksifm_blocking == handle->blocksifm) {
        result = 1;
      }
    } else {
      if ((handle->blocksifm_blocking == handle->blocksifm) && (handle->avoid_fmas_in_rim == 0)) {
        result = 1;
      }
    }
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_fwd_gemm_flags(libxsmm_dnn_layer* handle)
{
  int result = 0;
#if defined(LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS)
  /* If large image and NOT already loaded in accumulators, then use streaming stores */
  if ((handle->ofw >= 56) && (handle->desc.K >= 256) && (handle->avoid_acc_load == 1) && (handle->desc.R == 1) && (handle->desc.S == 1)) {
    result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
  }
  if (handle->ofw == 56 && handle->desc.C == 64 && handle->desc.K == 64 && handle->desc.R == 1) {
    result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
  }
  if (handle->ofw == 56 && handle->desc.C == 256 && handle->desc.K == 64 && handle->desc.R == 1) {
    result = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
  }
  /* Disable since the GEMM output is going to f32 scratch */
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 || handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
    result = 0;
  }
#else
  LIBXSMM_UNUSED(handle);
#endif
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    result = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_padding_copy(libxsmm_dnn_layer* handle)
{
  int result = 0;
  if ( (handle->desc.pad_h != handle->desc.pad_h_in) && (handle->desc.pad_w != handle->desc.pad_w_in) ) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_fwd_scratch(libxsmm_dnn_layer* handle)
{
  handle->fwd_packing_padding_scratch_size = 0;
  /* packing of input */
  if (handle->pack_input != 0) {
    handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      handle->desc.H / handle->desc.u *
      handle->desc.W / handle->desc.v *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* logical padding with copying on the fly */
  if (handle->fwd_padding_copy != 0) {
    handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      (handle->desc.H + 2 * handle->desc.pad_h) *
      (handle->desc.W + 2 * handle->desc.pad_w) *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* output buffer in high precision when we use BF16 */
  if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
    handle->fwd_lp_output_full_scratch_size = (size_t)LIBXSMM_MAX(
      handle->desc.threads * handle->fwd_gemm_pixels * handle->ofmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32),
      handle->desc.N * handle->desc.K * handle->ofwp * handle->ofhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32));
    handle->fwd_lp_output_block_scratch_size = (size_t)handle->desc.threads * handle->fwd_ofw_rb *
      handle->fwd_ofh_rb * handle->ofmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
  } else {
    handle->fwd_lp_output_full_scratch_size = 0;
    handle->fwd_lp_output_block_scratch_size = 0;
  }
  /* align sizes to full cacheline */
  handle->fwd_packing_padding_scratch_size += (handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0)
    ? 0 : LIBXSMM_CACHELINE - (handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
  handle->fwd_lp_output_full_scratch_size += (handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0)
    ? 0 : LIBXSMM_CACHELINE - (handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE);
  handle->fwd_lp_output_block_scratch_size += (handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE == 0)
    ? 0 : LIBXSMM_CACHELINE - (handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE);
  /* set offsets */
  handle->fwd_packing_padding_scratch_offset = 0;
  handle->fwd_lp_output_full_scratch_offset = handle->fwd_packing_padding_scratch_size;
  handle->fwd_lp_output_block_scratch_offset = handle->fwd_lp_output_full_scratch_offset +
    handle->fwd_lp_output_full_scratch_size;
  /* set overall scratch size for forward */
  handle->fwd_scratch_size = handle->fwd_packing_padding_scratch_size +
    handle->fwd_lp_output_full_scratch_size + handle->fwd_lp_output_block_scratch_size;
}
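Each scratch size is padded to a whole cacheline before the offsets are accumulated, so every per-buffer offset stays cacheline-aligned. The ternary used above is the usual round-up idiom; a short sketch, with LIBXSMM_CACHELINE assumed to be 64 here:

#include <stdio.h>

#define CACHELINE 64 /* stand-in for LIBXSMM_CACHELINE */

/* Round a byte count up to the next multiple of the cacheline size. */
static size_t pad_to_cacheline(size_t bytes)
{
  return bytes + ((bytes % CACHELINE == 0) ? 0 : CACHELINE - bytes % CACHELINE);
}

int main(void)
{
  printf("%zu %zu %zu\n", pad_to_cacheline(0), pad_to_cacheline(1), pad_to_cacheline(64));
  return 0; /* prints: 0 64 64 */
}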
/**********************************************************/
/* Helper functions for BWD convolutions' parameter setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_bwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* FIXME: Fallback if MB is not divisible by number of threads */
  if (handle->desc.N % handle->desc.threads != 0) {
    result = 1;
  }
  if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.pad_h != 0 || handle->desc.pad_w != 0)) {
    result = 1;
  }
  if ((handle->desc.R > 1 && handle->desc.pad_h == 0) || (handle->desc.S > 1 && handle->desc.pad_w == 0)) {
    result = 1;
  }
  if ((handle->desc.R > 1 && (handle->desc.pad_h_out == 0 || handle->desc.pad_h_in == 0)) ||
      (handle->desc.S > 1 && (handle->desc.pad_w_out == 0 || handle->desc.pad_w_in == 0))) {
    result = 1;
  }
  if ((handle->desc.R > 1 && handle->desc.u > 1) || (handle->desc.S > 1 && handle->desc.v > 1)) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofw_rb(libxsmm_dnn_layer* handle)
{
  int result = libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle);
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofh_rb(libxsmm_dnn_layer* handle)
{
  int result = libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle);
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_pixels_gemm(libxsmm_dnn_layer* handle)
{
  int result = handle->bwd_ofw_rb * handle->bwd_ofh_rb;
  /* In the case below we calculate redundant pixels in order to use AMX efficiently */
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    if (handle->desc.R != 1 || handle->desc.S != 1) {
      if (handle->ofw < 24) {
        result = (handle->bwd_ofw_rb + 2 * handle->desc.pad_w) * (handle->bwd_ofh_rb - 2) + 2 * (handle->bwd_ofw_rb + handle->desc.pad_w);
      }
    }
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_block_H(libxsmm_dnn_layer* handle)
{
  int result = 0;
  result = libxsmm_dnn_convolution_setup_fwd_block_H(handle);
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_bwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  result = libxsmm_dnn_convolution_setup_loop_order_fwd(handle);
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_IFM(libxsmm_dnn_layer* handle)
{
  int result = 0;
  result = LIBXSMM_MIN(handle->blocksifm, 16);
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_OFM(libxsmm_dnn_layer* handle)
{
  int result = 8;
  while (result % handle->blocksofm_blocking != 0) {
    result++;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_bwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  if ((handle->desc.u != 1) && (handle->bwd_ofh_rb != 1)) {
    result = 1;
  }
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ifm_parallelization(libxsmm_dnn_layer* handle)
{
  int result = 0;
  if (handle->ofw <= 7) {
    result = 1;
  }
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_bwd(libxsmm_dnn_layer* handle)
{
  int result = libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle);
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm_blocking(libxsmm_dnn_layer* handle)
{
  int result = 0;
  if (handle->desc.R == 1 && handle->desc.S == 1) {
    result = handle->blocksofm;
  } else {
    result = 1;
    if (handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 7 && handle->ofw == 7) {
      result = 2;
    }
  }
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    result = handle->blocksofm;
  }
  if (handle->blocksofm % result != 0) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_bwd_gemm_flags(libxsmm_dnn_layer* handle)
{
  int result = 0;
  LIBXSMM_UNUSED(handle);
  if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
      ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8))) {
    result = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
  }
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_spread_input_bwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  LIBXSMM_UNUSED(handle);
  if (((handle->desc.u != 1) || (handle->desc.v != 1)) && (handle->bwd_ofh_rb == 1)) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load_bwd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) {
    if ((handle->desc.R == 1) && (handle->desc.S == 1)) {
      if (handle->blocksofm_blocking == handle->blocksofm) {
        result = 1;
      }
    } else {
      if ((handle->blocksofm_blocking == handle->blocksofm) && (handle->avoid_fmas_in_rim == 0)) {
        result = 1;
      }
    }
  }
  return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bwd_scratch(libxsmm_dnn_layer* handle)
{
  /* transpose of weights */
  handle->bwd_filter_trans_scratch_size = (size_t)handle->desc.C * handle->desc.K *
    handle->desc.R * handle->desc.S * libxsmm_dnn_typesize(handle->datatype_in);
  handle->bwd_packing_padding_scratch_size = 0;
  /* packing of input */
  if (handle->pack_input_bwd != 0) {
    handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      handle->ofhp * handle->ofwp * libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* logical padding with copying on the fly */
  if (handle->use_fallback_bwd_loops != 0) {
    handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock *
      (handle->desc.H + 2 * handle->desc.pad_h) *
      (handle->desc.W + 2 * handle->desc.pad_w) *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* input buffer in high precision when we use BF16 */
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
    handle->bwd_lp_input_full_scratch_size = (size_t)LIBXSMM_MAX(
      handle->desc.threads * handle->bwd_gemm_pixels * handle->ifmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32),
      handle->desc.N * handle->desc.C * handle->ifwp * handle->ifhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32));
    /* logical padding with copying on the fly */
    if (handle->use_fallback_bwd_loops != 0) {
      handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock *
        (handle->desc.H + 2 * handle->desc.pad_h) *
        (handle->desc.W + 2 * handle->desc.pad_w) *
        libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
    }
  } else {
    handle->bwd_lp_input_full_scratch_size = 0;
  }
  /* align sizes to full cacheline */
  handle->bwd_filter_trans_scratch_size += (handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE == 0)
    ? 0 : LIBXSMM_CACHELINE - (handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE);
  handle->bwd_packing_padding_scratch_size += (handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0)
    ? 0 : LIBXSMM_CACHELINE - (handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
  handle->bwd_lp_input_full_scratch_size += (handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0)
    ? 0 : LIBXSMM_CACHELINE - (handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE);
  /* set offsets */
  handle->bwd_filter_trans_scratch_offset = 0;
  handle->bwd_packing_padding_scratch_offset = handle->bwd_filter_trans_scratch_size;
  handle->bwd_lp_input_full_scratch_offset = handle->bwd_packing_padding_scratch_offset +
    handle->bwd_packing_padding_scratch_size;
  /* set overall scratch size for backward */
  handle->bwd_scratch_size = handle->bwd_filter_trans_scratch_size +
    handle->bwd_packing_padding_scratch_size + handle->bwd_lp_input_full_scratch_size;
}
/**********************************************************/
/* Helper functions for UPD convolutions' parameter setup */
/**********************************************************/
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_upd(libxsmm_dnn_layer* handle)
{
  int result = 1;
  if (handle->ofh == 28 && handle->desc.R == 1 && handle->desc.u == 1 && handle->desc.C == 128 && handle->desc.K == 512) {
    result = 0;
  }
  if (handle->ofh == 28 && handle->desc.R == 3 && handle->desc.u == 1 && handle->desc.C == 128 && handle->desc.K == 128) {
    result = 0;
  }
  if (handle->ofw == 28 && handle->desc.R == 1 && handle->desc.C == 256 && handle->desc.K == 512) {
    result = 0;
  }
  if (handle->ofw == 14 && !(handle->desc.R == 1 && handle->desc.C == 1024 && handle->desc.K == 256)) {
    result = 0;
  }
  if (handle->ofw == 7) {
    result = 0;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_upd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* Pack input only for very small images, 1x1 convs, with large K to amortize the relevant overhead */
  if ((handle->ofh <= 7) && (handle->desc.R == 1) && (handle->desc.S == 1) &&
      (handle->desc.u != 1) && (handle->desc.v != 1) && (handle->desc.K >= 2048)) {
    result = 1;
  }
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd(libxsmm_dnn_layer* handle)
{
  int result = 0;
  /* Avoid rim FMAs only for small images */
  if ( (handle->ofh <= 7) && (handle->desc.R == 3) && (handle->desc.S == 3) &&
       (handle->desc.pad_w == 1) && (handle->desc.pad_h == 1)) {
    result = 1;
  }
  if (handle->desc.N != handle->desc.threads) {
    result = 0;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofw_rb(libxsmm_dnn_layer* handle)
{
  int result = 1;
  result = handle->ofw;
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofh_rb(libxsmm_dnn_layer* handle)
{
  int result = 1;
  /* Restrict the reduction chain which is ofw_rb*ofh_rb */
  if (handle->ofh <= 28) {
    result = handle->ofh;
  }
  /* In the following scenario with strided convolutions and non batch-reduce kernel make sure we have ofh_rb = 1 */
  if ((handle->desc.u != 1) && (handle->desc.v != 1) && (handle->upd_use_batchreduce == 0) && (handle->upd_pack_input == 0)) {
    result = 1;
  }
  /* If using linearized taskview and have strided convs, make sure ofh_rb is 1. */
  if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 0 && handle->upd_pack_input == 0 && handle->desc.u != 1) {
    result = 1;
  }
  if (handle->upd_linearized_tasklist == 1 && handle->upd_use_batchreduce == 0 && (handle->desc.R != 1 || handle->desc.S != 1)) {
    result = 1;
  }
  if (handle->upd_linearized_tasklist == 0 && handle->upd_use_batchreduce == 0 && (handle->desc.R != 1 || handle->desc.S != 1)) {
    result = 1;
  }
  if (handle->ofw == 56 && handle->desc.R == 1) {
    result = 2;
  }
  if (handle->upd_linearized_tasklist == 1 && handle->upd_use_batchreduce == 1 && handle->upd_avoid_rim_fmas == 1) {
    result = handle->ofh;
  }
  if ((handle->desc.N != handle->desc.threads) && (handle->desc.R > 1 || handle->desc.S > 1) && (handle->desc.u > 1 || handle->desc.v > 1)) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_IFM(libxsmm_dnn_layer* handle)
{
  int result = 1;
  if (handle->ofh == 56 && handle->desc.R == 1 && handle->desc.S == 1 && handle->desc.u == 1 && handle->desc.v == 1) {
    result = 4;
  }
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_OFM(libxsmm_dnn_layer* handle)
{
  int result = 1;
  LIBXSMM_UNUSED(handle);
  return result;
}

LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_img_batchreduce_block(libxsmm_dnn_layer* handle)
{
  int result = 1;
  LIBXSMM_UNUSED(handle);
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_batchreduce_upd(libxsmm_dnn_layer* handle)
{
  int result = 1;
  /* If W is large, no need for batchreduce kernel */
  if (handle->ofw >= 56) {
    result = 0;
  }
  /* If we have packed the input, then disable batch-reduce GEMM */
  if (handle->upd_pack_input == 1) {
    result = 0;
  }
  if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 0) {
    result = 0;
  }
  if (handle->upd_linearized_tasklist == 1 && handle->upd_avoid_rim_fmas == 1) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_weight_copies_upd( libxsmm_dnn_layer* handle ) {
  int result = handle->desc.threads;
  if (handle->ofw <= 14) {
    result = 9;
  }
  if (handle->ofw == 14 && handle->desc.N == 92 && handle->desc.threads == 92) {
    result = 23;
  }
  if (handle->ofw == 7 && handle->desc.N == 92 && handle->desc.threads == 92 && handle->desc.R == 3 && handle->desc.S == 3 && handle->desc.u == 1 && handle->desc.v == 1) {
    result = 23;
  }
  while (handle->desc.threads % result != 0) {
    result--;
  }
  /* FIXME: Hardcoded logic for N=27, N=26 */
  if (handle->desc.N == 27 && handle->desc.threads == 27 && handle->desc.R == 1 && handle->ofw == 14 && handle->desc.u == 1) {
    result = 7;
  }
  if (((handle->ofh == 14) || (handle->ofw == 7 && handle->desc.u == 2)) && handle->desc.N == 26 && handle->desc.threads == 26) {
    result = 13;
  }
  if ((handle->desc.N != handle->desc.threads) && !(handle->upd_linearized_tasklist == 0 && handle->upd_use_batchreduce == 0)) {
    result = handle->desc.N;
  }
  /* Make sure we keep a single weight copy when using the linearized-task view */
  if (handle->upd_linearized_tasklist == 1) {
    result = 1;
  }
  return result;
}
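/* A minimal sketch of the divisor search above (illustrative addition, assuming nothing
 * beyond what the while-loop already does): decrement the initial guess until it evenly
 * divides the thread count, e.g. threads = 28 with a guess of 9 settles on 7 copies. */
#if 0
static int largest_divisor_at_most(int threads, int guess) {
  int result = guess;
  while (threads % result != 0) result--;
  return result; /* largest_divisor_at_most(28, 9) == 7 */
}
#endif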
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_linearized_tasklist_upd( libxsmm_dnn_layer* handle ) {
  int result = 0;
  /* Use a linearized task-list (i.e. no reduction) only for small images and large filters */
  if (handle->ofh <= 10 && handle->ofw <= 10) {
    result = 1;
  }
  if (handle->ofw == 7 && handle->desc.N == 92 && handle->desc.threads == 92 && handle->desc.R == 3 && handle->desc.S == 3 && handle->desc.u == 1 && handle->desc.v == 1) {
    result = 0;
  }
  if (handle->ofh == 14 && handle->ofw == 14 && handle->desc.N == 23 && handle->desc.threads == 23) {
    result = 1;
  }
#if 0
  if ((handle->blocksofm * handle->blocksifm * handle->desc.R * handle->desc.S > (handle->desc.threads * 4)) && (handle->ofh <= 56)) {
    result = 1;
  }
#endif
  if (handle->desc.u == 2 && handle->desc.v == 2 && handle->desc.K == 512) {
    result = 0;
  }
  return result;
}
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_upd_gemm_flags( libxsmm_dnn_layer* handle ) {
  int result = 0;
  LIBXSMM_UNUSED(handle);
  return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd( libxsmm_dnn_layer* handle ) {
  int remainder_pixels, max_init_offset, max_compute_offset_input, input_compute_pad, accum_length_pixels, compute_pixels;
  const int multiple_target = 2;
  int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
  int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
  int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
  int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;
  handle->upd_linearized_pixels = 1;
  if (handle->desc.S != 1 && handle->desc.v != 1) {
    handle->upd_linearized_pixels = 0;
    handle->upd_trans_w_only = 0;
  }
  /* For large images facilitate the "large" transposes by blocking the pixel/reduction domains */
  if (handle->ofw >= 56 && handle->ofh >= 56 && handle->desc.R == 1 && handle->desc.S == 1 && handle->desc.u == 1 && handle->desc.v == 1) {
    handle->upd_linearized_pixels = 0;
    handle->upd_trans_w_only = 1;
  }
  handle->on_the_fly_input_packing = 0;
  handle->upd_pack_input_upfront = 0;
  handle->use_hybrid_imgofm_parallelization = 0;
  handle->upd_linearized_tasklist = 0;
  if (handle->upd_linearized_pixels == 1) {
    /* Logistics to pad accumulation chainlength */
    compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh - 1);
    remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels / multiple_target + 1) * multiple_target - compute_pixels;
    accum_length_pixels = compute_pixels + remainder_pixels;
    /* In this case compact input upfront */
    if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.u != 1 || handle->desc.v != 1)) {
      handle->upd_pack_input_upfront = 1;
    }
    /* Logistics for input transpose and additional pixel padding */
    max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
    max_compute_offset_input = max_init_offset + accum_length_pixels;
    input_compute_pad = (max_compute_offset_input > IFWP * IFHP) ? max_compute_offset_input - IFWP * IFHP : 0;
    handle->input_pixels = IFWP * IFHP + input_compute_pad;
    if (handle->upd_pack_input_upfront) {
      handle->input_pixels = accum_length_pixels;
    }
    handle->output_pixels = accum_length_pixels;
    handle->pixel_blocking = accum_length_pixels;
    handle->n_used_pixels = accum_length_pixels;
    handle->compute_pixels = compute_pixels;
    handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1;
    if (handle->ofw <= 14) {
      handle->use_hybrid_imgofm_parallelization = 1;
      handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle);
      if (handle->ofw == 14 && handle->desc.K >= 1024) {
        handle->use_hybrid_imgofm_parallelization = 0;
        handle->weight_copies = handle->desc.threads;
      }
    } else {
      handle->weight_copies = handle->desc.threads;
    }
  }
  if (handle->upd_linearized_pixels == 0) {
    handle->weight_copies = handle->desc.threads;
    if (handle->desc.v != 1) {
      handle->on_the_fly_input_packing = 1;
    }
    remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw / multiple_target + 1) * multiple_target - handle->ofw;
    handle->ofwp_extended = OFWP + remainder_pixels;
    handle->ifwp_extended = IFWP + remainder_pixels;
    handle->output_pixels = OFHP * handle->ofwp_extended;
    /* coverity[identical_branches] */
    handle->batchreduce_h_pixels = (handle->upd_trans_w_only) ? 1 : 1;
    /* TODO: identical_branches */
    handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1;
  }
  if (handle->desc.N != handle->desc.threads) {
    handle->use_intermediate_f32_wt_tensor = 1;
    handle->use_hybrid_imgofm_parallelization = 0;
    handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads);
  }
}
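/* Illustrative note (an addition): remainder_pixels above rounds the compute-pixel count up
 * to the next multiple of multiple_target, i.e.
 *   padded = (n % m == 0) ? n : (n / m + 1) * m;
 * e.g. n = 53 pixels with m = 2 pads by 1 pixel to an accumulation chain of 54. */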
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd_amx( libxsmm_dnn_layer* handle ) {
  /* JIT related variables... */
  libxsmm_blasint LDA = handle->ofmblock;
  libxsmm_blasint LDB = handle->input_pixels;
  libxsmm_blasint LDC = handle->ofmblock;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
  int l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
  size_t stride_a, stride_b;
  int unroll_hint;
  float beta;
  int remainder_pixels, max_init_offset, max_compute_offset_input, input_compute_pad, accum_length_pixels, compute_pixels;
  const int multiple_target = 32;
  int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
  int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
  int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;
  handle->upd_linearized_pixels = 1;
  if (handle->desc.S != 1 && handle->desc.v != 1) {
    handle->upd_linearized_pixels = 0;
  }
  handle->fuse_upd_transposes = 1;
  handle->pack_to_cnhw = 0;
  handle->on_the_fly_input_packing = 0;
  handle->upd_pack_input_upfront = 0;
  handle->use_hybrid_imgofm_parallelization = 0;
  handle->upd_linearized_tasklist = 0;
  if (((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) && (handle->ofw == 7) && (handle->desc.R == 1) && (handle->desc.S == 1)) {
    handle->pack_to_cnhw = 1;
  }
  if (handle->upd_linearized_pixels == 1) {
    if (handle->pack_to_cnhw == 0) {
      handle->fuse_upd_transposes = 1;
      /* Logistics to pad accumulation chainlength */
      compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh - 1);
      remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels / multiple_target + 1) * multiple_target - compute_pixels;
      accum_length_pixels = compute_pixels + remainder_pixels;
      /* In this case compact input upfront */
      if (handle->desc.R == 1 && handle->desc.S == 1 && (handle->desc.u != 1 || handle->desc.v != 1)) {
        handle->upd_pack_input_upfront = 1;
      }
      /* Logistics for input transpose and additional pixel padding */
      max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
      max_compute_offset_input = max_init_offset + accum_length_pixels;
      input_compute_pad = (max_compute_offset_input > IFWP * IFHP) ? max_compute_offset_input - IFWP * IFHP : 0;
      handle->input_pixels = IFWP * IFHP + input_compute_pad;
      if (handle->upd_pack_input_upfront) {
        handle->input_pixels = accum_length_pixels;
      }
      handle->output_pixels = accum_length_pixels;
      handle->pixel_blocking = accum_length_pixels;
      handle->n_used_pixels = accum_length_pixels;
      handle->compute_pixels = compute_pixels;
      handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1;
#if 0
      handle->scratch2_size = (size_t) (handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(float)/2);
      if (handle->use_intermediate_f32_wt_tensor) {
        handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float);
      }
      handle->scratch3_size = (size_t) (handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(float)/2);
#endif
      if (handle->ofw <= 14) {
        handle->use_hybrid_imgofm_parallelization = 1;
        handle->fuse_upd_transposes = 0;
      } else {
        handle->weight_copies = handle->desc.threads;
      }
      if ((handle->ofmblock % 32 != 0) || (handle->ifmblock % 32 != 0)) {
        handle->fuse_upd_transposes = 0;
      }
    } else {
      /* Logistics to pad accumulation chainlength */
      handle->use_hybrid_imgofm_parallelization = 1;
      handle->weight_copies = 7;
      while (handle->desc.threads % handle->weight_copies != 0) {
        handle->weight_copies--;
      }
      compute_pixels = handle->ofw * handle->ofh * (handle->desc.N / handle->weight_copies);
      remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels / multiple_target + 1) * multiple_target - compute_pixels;
      handle->remainder_pixels = remainder_pixels;
      accum_length_pixels = compute_pixels + remainder_pixels;
      handle->output_pixels = accum_length_pixels * handle->weight_copies;
      handle->input_pixels = accum_length_pixels * handle->weight_copies;
      handle->pixel_blocking = accum_length_pixels;
      handle->n_used_pixels = accum_length_pixels;
      handle->use_intermediate_f32_wt_tensor = 0;
#if 0
      handle->scratch2_size = (size_t) (handle->weight_copies * handle->output_pixels * handle->desc.K * sizeof(float)/2);
      handle->scratch3_size = (size_t) (handle->weight_copies * handle->input_pixels * handle->desc.C * sizeof(float)/2);
#endif
    }
  }
  if (handle->upd_linearized_pixels == 0) {
    handle->weight_copies = handle->desc.threads;
    if (handle->desc.v != 1) {
      handle->on_the_fly_input_packing = 1;
    }
    remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw / multiple_target + 1) * multiple_target - handle->ofw;
    handle->remainder_pixels = remainder_pixels;
    handle->ofwp_extended = OFWP + remainder_pixels;
    handle->ifwp_extended = IFWP + remainder_pixels;
    handle->batchreduce_h_pixels = handle->ofh;
    handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1;
#if 0
    handle->scratch2_size = (size_t) (handle->desc.N * handle->ofhp*handle->ofwp_extended * handle->desc.K * sizeof(float)/2);
    if (handle->use_intermediate_f32_wt_tensor) {
      handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float);
    }
    handle->scratch3_size = (size_t) (handle->desc.N * handle->ifhp * handle->ifwp_extended * handle->desc.C * sizeof(float)/2);
#endif
  }
  /* Now that all decisions have been made, JIT the proper kernel... */
  beta = (handle->use_intermediate_f32_wt_tensor) ? (float)1.0 : (float)0.0;
  if (handle->upd_linearized_pixels == 0) {
    LDA = handle->ofmblock;
    LDB = IFHP * handle->ifwp_extended;
    LDC = handle->ofmblock;
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    unroll_hint = handle->batchreduce_h_pixels;
    stride_a = handle->ofwp_extended * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
    stride_b = handle->desc.u * handle->ifwp_extended * libxsmm_dnn_typesize(handle->datatype_in);
    handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->ofw + handle->remainder_pixels, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
    handle->upd_compute_kernel_brgemm_no_linearized_pixels = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->ifmblock, handle->ofw + handle->remainder_pixels, (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, unroll_hint, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
  } else {
    LDA = handle->ofmblock;
    LDB = handle->input_pixels;
    LDC = handle->ofmblock;
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    if (handle->use_hybrid_imgofm_parallelization == 0) {
      handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
      handle->upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
    } else {
      if (handle->pack_to_cnhw == 1) {
        handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
        handle->upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
      } else {
        /* TODO: Hoist here hybrid parallelization logic and then we should be able to also provide unroll hint in the BRGEMM call */
        stride_a = handle->blocksofm * handle->output_pixels * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
        stride_b = handle->blocksifm * handle->ifmblock * handle->input_pixels * libxsmm_dnn_typesize(handle->datatype_in);
        handle->upd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_tc_flags, NULL);
        handle->upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw = libxsmm_bsmmdispatch_reducebatch_strd(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
      }
    }
  }
  if (handle->desc.N != handle->desc.threads) {
    handle->use_intermediate_f32_wt_tensor = 1;
    handle->use_hybrid_imgofm_parallelization = 0;
    handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads);
  }
}
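/* Illustrative note (an addition): in the strided batch-reduce GEMMs dispatched above,
 * stride_a and stride_b are the byte distances between consecutive A/B blocks of the
 * reduction batch; the JITed kernel effectively computes C += sum_i A_i * B_i with
 * A_i = A + i*stride_a and B_i = B + i*stride_b, and the unroll hint passed here is the
 * per-call batch length (batchreduce_h_pixels). */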
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_padding_copy( libxsmm_dnn_layer* handle ) {
  int result = 0;
  if ((handle->desc.pad_h != handle->desc.pad_h_in) && (handle->desc.pad_w != handle->desc.pad_w_in)) {
    result = 1;
  }
  return result;
}
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_upd_scratch( libxsmm_dnn_layer* handle ) {
  handle->upd_packing_padding_scratch_size = 0;
  /* packing of input */
  if (handle->upd_pack_input != 0) {
    handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      handle->desc.H / handle->desc.u *
      handle->desc.W / handle->desc.v *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* logical padding with copying on the fly */
  if (handle->upd_padding_copy != 0) {
    handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      (handle->desc.H + 2 * handle->desc.pad_h) *
      (handle->desc.W + 2 * handle->desc.pad_w) *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* output/input buffer to transpose when we use bf16 */
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
    if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
      int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
      if (handle->upd_linearized_pixels == 1) {
        handle->upd_lp_output_full_scratch_size = (size_t)(handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(handle->datatype_in));
        handle->upd_lp_input_full_scratch_size = (size_t)(handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(handle->datatype_in));
      }
      if (handle->upd_linearized_pixels == 0) {
        handle->upd_lp_output_full_scratch_size = (size_t)(handle->desc.N * OFHP * handle->ofwp_extended * handle->desc.K * sizeof(handle->datatype_in));
        handle->upd_lp_input_full_scratch_size = (size_t)(handle->desc.N * IFHP * handle->ifwp_extended * handle->desc.C * sizeof(handle->datatype_in));
      }
    } else {
      const int multiple_target = 2;
      int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
      int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
      int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
      int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;
      if (handle->upd_linearized_pixels == 1) {
        int compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh - 1);
        int remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels / multiple_target + 1) * multiple_target - compute_pixels;
        int accum_length_pixels = compute_pixels + remainder_pixels;
        int max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
        int max_compute_offset_input = max_init_offset + accum_length_pixels;
        int input_compute_pad = (max_compute_offset_input > IFWP * IFHP) ? max_compute_offset_input - IFWP * IFHP : 0;
        int input_pixels = IFWP * IFHP + input_compute_pad;
        if (handle->upd_pack_input_upfront == 1) {
          input_pixels = accum_length_pixels;
        }
        handle->upd_lp_output_full_scratch_size = (size_t)(handle->desc.N * accum_length_pixels * handle->desc.K * sizeof(handle->datatype_in));
        handle->upd_lp_input_full_scratch_size = (size_t)(handle->desc.N * input_pixels * handle->desc.C * sizeof(handle->datatype_in));
      }
      if (handle->upd_linearized_pixels == 0) {
        int remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw / multiple_target + 1) * multiple_target - handle->ofw;
        int ofwp_extended = OFWP + remainder_pixels;
        int ifwp_extended = IFWP + remainder_pixels;
        handle->upd_lp_output_full_scratch_size = (size_t)(handle->desc.N * OFHP * ofwp_extended * handle->desc.K * sizeof(handle->datatype_in));
        handle->upd_lp_input_full_scratch_size = (size_t)(handle->desc.N * IFHP * ifwp_extended * handle->desc.C * sizeof(handle->datatype_in));
      }
    }
    handle->upd_lp_filter_full_scratch_size = (size_t)handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
  } else {
    handle->upd_lp_output_full_scratch_size = 0;
    handle->upd_lp_input_full_scratch_size = 0;
    handle->upd_lp_filter_full_scratch_size = 0;
  }
  /* filter scratch */
  handle->upd_filter_scratch_size = (size_t)handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * LIBXSMM_MAX(handle->desc.threads, handle->desc.N) * sizeof(float);
  /* align sizes to full cacheline */
  handle->upd_packing_padding_scratch_size += (handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - (handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_lp_output_full_scratch_size += (handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - (handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_lp_input_full_scratch_size += (handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - (handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_filter_scratch_size += (handle->upd_filter_scratch_size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - (handle->upd_filter_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_lp_filter_full_scratch_size += (handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - (handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE);
  /* calculate offsets */
  handle->upd_packing_padding_scratch_offset = 0;
  handle->upd_lp_output_full_scratch_offset = handle->upd_packing_padding_scratch_size;
  handle->upd_lp_input_full_scratch_offset = handle->upd_lp_output_full_scratch_offset + handle->upd_lp_output_full_scratch_size;
  handle->upd_filter_scratch_offset = handle->upd_lp_input_full_scratch_offset + handle->upd_lp_input_full_scratch_size;
  handle->upd_lp_filter_full_scratch_offset = handle->upd_filter_scratch_offset + handle->upd_filter_scratch_size;
  /* set overall scratch size for update */
  handle->upd_scratch_size = handle->upd_packing_padding_scratch_size +
    handle->upd_lp_output_full_scratch_size +
    handle->upd_lp_input_full_scratch_size +
    handle->upd_filter_scratch_size +
    handle->upd_lp_filter_full_scratch_size;
}
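/* A minimal sketch of the cacheline round-up used above (illustrative addition):
 *   pad = (size % LIBXSMM_CACHELINE == 0) ? 0 : LIBXSMM_CACHELINE - (size % LIBXSMM_CACHELINE);
 * e.g. a 100-byte request with a 64-byte cacheline grows by 28 bytes to 128, so every
 * sub-buffer offset computed above starts cacheline-aligned. */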
LIBXSMM_API_INLINE libxsmm_dnn_err_t libxsmm_dnn_convolution_setup( libxsmm_dnn_layer* handle ) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  libxsmm_blasint _ldi = 64, _ldo = 64;
  libxsmm_blasint ldx;
  libxsmm_blasint ldA;
  libxsmm_blasint ldC;
  int beta_int;
  float beta;
  int l_flags;
  int l_tc_flags;
  /* init libxsmm */
  LIBXSMM_INIT
  /* Generic parameter setup */
  handle->target_archid = libxsmm_target_archid;
  if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) &&
       (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) &&
       ((handle->desc.C % 16 != 0) || (handle->desc.K % 16 != 0)) ) {
    handle->target_archid = LIBXSMM_X86_AVX512_CPX;
  }
  handle->ifmblock = libxsmm_dnn_convolution_setup_ifmblock(handle);
  handle->ofmblock = libxsmm_dnn_convolution_setup_ofmblock(handle);
  handle->fm_lp_block = libxsmm_dnn_convolution_setup_fm_lp_block(handle);
  handle->blocksifm = libxsmm_dnn_convolution_setup_blocksifm(handle);
  handle->blocksofm = libxsmm_dnn_convolution_setup_blocksofm(handle);
  /* If in SPR, generate tilerelease kernel */
  if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
    int l_tr_flags = LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
    handle->tilerelease_kernel = libxsmm_bsmmdispatch(handle->ifmblock, handle->ifmblock, handle->ifmblock, NULL, NULL, NULL, NULL, NULL, &l_tr_flags, NULL);
  }
  /* FWD parameter setup */
  handle->fwd_ofw_rb = libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle);
  handle->pack_input = libxsmm_dnn_convolution_setup_pack_input_fwd(handle);
  handle->fwd_ofh_rb = libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle);
  handle->fwd_gemm_pixels = libxsmm_dnn_convolution_setup_fwd_pixels_gemm(handle);
  handle->block_fwd_oj = libxsmm_dnn_convolution_setup_fwd_block_H(handle);
  handle->loop_order = libxsmm_dnn_convolution_setup_loop_order_fwd(handle);
  handle->blocksifm_blocking = libxsmm_dnn_convolution_setup_blocksifm_blocking(handle);
  handle->block_fwd_ofm = libxsmm_dnn_convolution_setup_block_fwd_OFM(handle);
  handle->block_fwd_ifm = libxsmm_dnn_convolution_setup_block_fwd_IFM(handle);
  handle->avoid_fmas_in_rim = libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle);
  handle->use_ofm_parallelization = libxsmm_dnn_convolution_setup_use_ofm_parallelization(handle);
  handle->shuffle_filter_accesses = libxsmm_dnn_convolution_setup_shuffle_filter_accesses(handle);
  handle->avoid_acc_load = libxsmm_dnn_convolution_setup_avoid_acc_load(handle);
  handle->fwd_flags = libxsmm_dnn_convolution_setup_init_fwd_gemm_flags(handle);
  handle->use_fallback_fwd_loops = libxsmm_dnn_convolution_setup_fallback_loops_fwd(handle);
  handle->fwd_padding_copy = libxsmm_dnn_convolution_setup_fwd_padding_copy(handle);
#if 0
  if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) {
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
    handle->block_fwd_ofm = 1;
    handle->block_fwd_oj = handle->fwd_ofh_rb;
    ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
    ldA = handle->ofmblock;
    ldC = handle->ofmblock;
    beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0;
    l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
    if ( 0 == env_brgemm_pf_oob ) {
    } else {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (brgemm_pf_oob > 0) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }
    handle->fwd_compute_kernel_offs_f32 = NULL;
    handle->fwd_compute_kernel_strd_f32 = NULL;
    handle->fwd_compute_kernel_addr_a_f32 = NULL;
    handle->fwd_compute_kernel_addr_b_f32 = NULL;
    if (handle->desc.R == 1 && handle->desc.S == 1) {
      const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
      const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
      int stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
      int stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in);
      handle->fwd_compute_kernel_strd_f32 = libxsmm_smmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, stride_a, stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    } else {
      const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
      const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
      int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
      int i = 0, ifm, ki, kj;
      handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
        for (kj = 0; kj < handle->desc.R; kj++) {
          for (ki = 0; ki < handle->desc.S; ki++) {
            handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
                kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
                ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
            handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
                kj * IFW * handle->ifmblock +
                ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in);
            i++;
          }
        }
      }
      handle->fwd_compute_kernel_offs_f32 = libxsmm_smmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    }
    handle->fwd_compute_kernel_addr_a_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    handle->fwd_compute_kernel_addr_b_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
  }
#endif
  if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) &&
       (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ) {
    handle->block_fwd_ofm = 1;
    handle->block_fwd_oj = handle->fwd_ofh_rb;
    ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
    ldA = handle->ofmblock;
    ldC = handle->ofmblock;
    beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0;
    l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
    l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
    handle->fwd_compute_kernel_addr = NULL;
    handle->fwd_compute_kernel_offs_a = NULL;
    handle->fwd_compute_kernel_offs_b = NULL;
    handle->fwd_compute_kernel_strd = NULL;
    if (handle->desc.R == 1 && handle->desc.S == 1) {
      const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
      const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
      size_t stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
      size_t stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in);
      handle->fwd_compute_kernel_strd = libxsmm_bmmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    } else {
      const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
      const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
      int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
      int i = 0, ifm, ki, kj;
      handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
        for (kj = 0; kj < handle->desc.R; kj++) {
          for (ki = 0; ki < handle->desc.S; ki++) {
            handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
                kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
                ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
            handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
                kj * IFW * handle->ifmblock +
                ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in);
            i++;
          }
        }
      }
      handle->fwd_compute_kernel_offs_a = libxsmm_bmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
      handle->fwd_compute_kernel_offs_b = libxsmm_bsmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    }
    handle->fwd_config_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_tc_flags, NULL);
  }
  handle->code_fwd[0].ptr = 0;
  handle->code_fwd[1].ptr = 0;
  handle->code_fwd[2].ptr = 0;
  /* JIT cvt eltwise functions for fwd convolutions */
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
    _ldi = handle->ofmblock * handle->ofwp;
    _ldo = handle->ofmblock * handle->ofwp;
    handle->fwd_cvtfp32bf16_kernel = libxsmm_dispatch_meltw_unary(handle->ofmblock * handle->fwd_ofw_rb, handle->fwd_ofh_rb, &_ldi, &_ldo, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_BF16, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_IDENTITY);
  }
  /* Create strided BRGEMMs for i8i32 convolutions */
  if ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I32)) {
    ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
    ldA = handle->ofmblock;
    ldC = handle->ofmblock;
    beta_int = (handle->avoid_acc_load) ? 0 : 1;
    l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
    if (handle->desc.R == 1 && handle->desc.S == 1) {
      const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
      const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
      libxsmm_blasint stride_A = handle->ifmblock * handle->ofmblock * sizeof(char);
      libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char);
      handle->gemm_fwd.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
    } else {
      const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
      const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
      if (handle->avoid_fmas_in_rim == 0) {
        int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
        int i = 0, ifm, ki, kj;
        handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
        handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
        for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
          for (kj = 0; kj < handle->desc.R; kj++) {
            for (ki = 0; ki < handle->desc.S; ki++) {
              handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
                  kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
                  ki * handle->ifmblock * handle->ofmblock) * sizeof(char);
              handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
                  kj * IFW * handle->ifmblock +
                  ki * handle->ifmblock) * sizeof(char);
              i++;
            }
          }
        }
        handle->gemm_fwd.xgemm.subimro = libxsmm_subimmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
      } else {
        libxsmm_blasint stride_A = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(char);
        libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char);
        handle->gemm_fwd.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
        handle->gemm_fwd2.xgemm.subimrs = libxsmm_subimmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
      }
    }
  } else if ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I8)) {
    ldx = (libxsmm_blasint)handle->desc.v*handle->ifmblock;
    ldA = handle->ofmblock;
    ldC = handle->ofmblock;
    beta_int = 0;
    l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
    if (handle->desc.R == 1 && handle->desc.S == 1) {
      const int IFW = handle->ifwp;
      const int IFH = handle->ifhp;
      libxsmm_blasint stride_A = handle->ifmblock * handle->ofmblock * sizeof(char);
      libxsmm_blasint stride_B = handle->ifmblock * IFW * IFH * sizeof(char);
      handle->gemm_fwd.xgemm.sububmrs = libxsmm_sububmmdispatch_reducebatch_strd(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, stride_A, stride_B, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
    } else {
      const int IFW = handle->ifwp;
      const int IFH = handle->ifhp;
      int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
      int i = 0, ifm, ki, kj;
      handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
        for (kj = 0; kj < handle->desc.R; kj++) {
          for (ki = 0; ki < handle->desc.S; ki++) {
            handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
                kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
                ki * handle->ifmblock * handle->ofmblock) * sizeof(char);
            handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
                kj * IFW * handle->ifmblock +
                ki * handle->ifmblock) * sizeof(char);
            i++;
          }
        }
      }
      handle->gemm_fwd.xgemm.sububmro = libxsmm_sububmmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta_int, &l_flags, NULL);
    }
  }
#if 0
  /* Spit out FWD parameters that are selected... */
  printf("FWD params...\n");
  printf("Fwd_ofw_rb = %d\n", handle->fwd_ofw_rb);
  printf("Fwd_ofh_rb = %d\n", handle->fwd_ofh_rb);
  printf("Pack input = %d\n", handle->pack_input);
  printf("Block oj = %d\n", handle->block_fwd_oj);
  printf("Loop order = %d\n", handle->loop_order);
  printf("Blocksifm_blocking = %d\n", handle->blocksifm_blocking);
  printf("Block fwd ofm = %d\n", handle->block_fwd_ofm);
  printf("Block fwd ifm = %d\n", handle->block_fwd_ifm);
  printf("Avoid rim fmas = %d\n", handle->avoid_fmas_in_rim);
  printf("Ofm parallelization = %d\n", handle->use_ofm_parallelization);
  printf("Shuffle filter accesses = %d\n", handle->shuffle_filter_accesses);
  printf("Avoid acc load = %d\n", handle->avoid_acc_load);
  printf("Fwd GEMM flags = %d\n", handle->fwd_flags);
#endif
  /* BWD parameter setup */
  handle->bwd_ofw_rb = libxsmm_dnn_convolution_setup_bwd_ofw_rb(handle);
  handle->bwd_ofh_rb = libxsmm_dnn_convolution_setup_bwd_ofh_rb(handle);
  handle->bwd_gemm_pixels = libxsmm_dnn_convolution_setup_bwd_pixels_gemm(handle);
  handle->pack_input_bwd = libxsmm_dnn_convolution_setup_pack_input_bwd(handle);
  handle->spread_input_bwd = libxsmm_dnn_convolution_setup_spread_input_bwd(handle);
  handle->blocksofm_blocking = libxsmm_dnn_convolution_setup_blocksofm_blocking(handle);
  handle->avoid_acc_load_bwd = libxsmm_dnn_convolution_setup_avoid_acc_load_bwd(handle);
  handle->use_ifm_parallelization = libxsmm_dnn_convolution_setup_use_ifm_parallelization(handle);
  handle->block_bwd_ofm = libxsmm_dnn_convolution_setup_block_bwd_OFM(handle);
  handle->block_bwd_ifm = libxsmm_dnn_convolution_setup_block_bwd_IFM(handle);
  handle->block_bwd_oj = libxsmm_dnn_convolution_setup_bwd_block_H(handle);
  handle->use_fallback_bwd_loops = libxsmm_dnn_convolution_setup_fallback_loops_bwd(handle);
  handle->bwd_flags = libxsmm_dnn_convolution_setup_init_bwd_gemm_flags(handle);
  if ( ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) &&
       (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ) {
    handle->block_bwd_ifm = 1;
    handle->block_bwd_oj = handle->bwd_ofh_rb;
    ldx = ((libxsmm_blasint)handle->ofmblock);
    ldA = handle->ifmblock;
    ldC = (handle->spread_input_bwd == 1) ? handle->ifmblock * handle->desc.v : handle->ifmblock;
    beta = (handle->avoid_acc_load_bwd) ? (float)0.0 : (float)1.0;
    l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
    l_tc_flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') );
    handle->bwd_compute_kernel_addr = NULL;
    handle->bwd_compute_kernel_offs = NULL;
    handle->bwd_compute_kernel_strd = NULL;
    if (handle->desc.R == 1 && handle->desc.S == 1) {
      size_t stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
      size_t stride_b = handle->ofwp * handle->ofhp * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
      handle->bwd_compute_kernel_strd = libxsmm_bsmmdispatch_reducebatch_strd_unroll(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, (libxsmm_blasint)stride_a, (libxsmm_blasint)stride_b, handle->blocksofm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    } else {
      int n_blocks = handle->desc.R * handle->desc.S * handle->blocksofm_blocking;
      int i = 0, ofm, ki, kj;
      handle->A_offsets_bwd = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      handle->B_offsets_bwd = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
      for (ofm = 0; ofm < handle->blocksofm_blocking; ofm++) {
        for (kj = 0; kj < handle->desc.R; kj++) {
          for (ki = 0; ki < handle->desc.S; ki++) {
            handle->A_offsets_bwd[i] = (ofm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
                kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
                ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
            handle->B_offsets_bwd[i] = (ofm * handle->ofhp * handle->ofwp * handle->ofmblock +
                kj * handle->ofwp * handle->ofmblock +
                ki * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
            i++;
          }
        }
      }
      handle->bwd_compute_kernel_offs = libxsmm_bsmmdispatch_reducebatch_offs(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    }
    handle->bwd_config_kernel = libxsmm_bsmmdispatch(handle->ifmblock, handle->bwd_gemm_pixels, handle->ofmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_tc_flags, NULL);
  }
#if 0
  /* Spit out BWD parameters that are selected... */
  printf("BWD params...\n");
  printf("Bwd_ofw_rb = %d\n", handle->bwd_ofw_rb);
  printf("Bwd_ofh_rb = %d\n", handle->bwd_ofh_rb);
  printf("Pack input = %d\n", handle->pack_input_bwd);
  printf("Spread input = %d\n", handle->spread_input_bwd);
  printf("Blocksofm_blocking = %d\n", handle->blocksofm_blocking);
  printf("Avoid acc load = %d\n", handle->avoid_acc_load_bwd);
  printf("Ifm parallelization = %d\n", handle->use_ifm_parallelization);
  printf("Block bwd ofm = %d\n", handle->block_bwd_ofm);
  printf("Block bwd ifm = %d\n", handle->block_bwd_ifm);
  printf("Block oj = %d\n", handle->block_bwd_oj);
#endif
  handle->code_bwd[0].ptr = 0;
  handle->code_bwd[1].ptr = 0;
  handle->code_bwd[2].ptr = 0;
  /* Transpose kernel used for filter transpose in bwd pass */
  handle->tr_kernel = libxsmm_dispatch_meltw_unary(64, 16, &(_ldi), &(_ldo), LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, LIBXSMM_MELTW_FLAG_UNARY_NONE, LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT);
  /* UPD parameter setup */
  handle->upd_linearized_tasklist = libxsmm_dnn_convolution_setup_linearized_tasklist_upd(handle);
  handle->upd_avoid_rim_fmas = libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd(handle);
  handle->upd_pack_input = libxsmm_dnn_convolution_setup_pack_input_upd(handle);
  handle->upd_use_batchreduce = libxsmm_dnn_convolution_setup_use_batchreduce_upd(handle);
  handle->upd_ofw_rb = libxsmm_dnn_convolution_setup_upd_ofw_rb(handle);
  handle->upd_ofh_rb = libxsmm_dnn_convolution_setup_upd_ofh_rb(handle);
  handle->upd_loop_order = libxsmm_dnn_convolution_setup_loop_order_upd(handle);
  handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle);
  handle->block_upd_ofm = libxsmm_dnn_convolution_setup_block_upd_OFM(handle);
  handle->block_upd_ifm = libxsmm_dnn_convolution_setup_block_upd_IFM(handle);
  handle->upd_loop_order = libxsmm_dnn_convolution_setup_loop_order_upd(handle);
  handle->upd_padding_copy = libxsmm_dnn_convolution_setup_upd_padding_copy(handle);
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
    if ((handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
      libxsmm_dnn_convolution_setup_bf16_upd_amx(handle);
    } else {
      libxsmm_dnn_convolution_setup_bf16_upd(handle);
    }
  }
#if 0
  /* Spit out UPD parameters that are selected... */
  printf("UPD params...\n");
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
    printf("BF16 path...\n");
    printf("UPD use_hybrid_imgofm_parallelization = %d\n", handle->use_hybrid_imgofm_parallelization);
    printf("UPD linearized_pixels = %d\n", handle->upd_linearized_pixels);
    printf("UPD upd_trans_w_only = %d\n", handle->upd_trans_w_only);
    printf("UPD on_the_fly_input_packing = %d\n", handle->on_the_fly_input_packing);
    printf("UPD use_intermediate_f32_wt_tensor = %d\n", handle->use_intermediate_f32_wt_tensor);
    printf("UPD pack to CNHW format = %d\n", handle->pack_to_cnhw);
    printf("UPD batchreduce H pixels = %d\n", handle->batchreduce_h_pixels);
  }
  printf("UPD linearized tasks = %d\n", handle->upd_linearized_tasklist);
  printf("UPD avoid rim fmas = %d\n", handle->upd_avoid_rim_fmas);
  printf("UPD Pack input = %d\n", handle->upd_pack_input);
  printf("UPD use batch-reduce GEMM = %d\n", handle->upd_use_batchreduce);
  printf("Upd_ofw_rb = %d\n", handle->upd_ofw_rb);
  printf("Upd_ofh_rb = %d\n", handle->upd_ofh_rb);
  printf("UPD loop order = %d\n", handle->upd_loop_order);
  printf("UPD weight_copies = %d\n", handle->weight_copies);
  printf("Block upd ofm = %d\n", handle->block_upd_ofm);
  printf("Block upd ifm = %d\n", handle->block_upd_ifm);
#endif
  handle->code_upd[0].ptr = 0;
  handle->code_upd[1].ptr = 0;
  /* prepare barrier */
  handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
  /* set up scratch */
  libxsmm_dnn_convolution_setup_fwd_scratch(handle);
  libxsmm_dnn_convolution_setup_bwd_scratch(handle);
  libxsmm_dnn_convolution_setup_upd_scratch(handle);
  handle->scratch = 0;
  handle->scratch_size = LIBXSMM_MAX(handle->fwd_scratch_size, LIBXSMM_MAX(handle->bwd_scratch_size, handle->upd_scratch_size));
  return status;
}
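/* Illustrative note (an addition): the forward, backward and update passes share a single
 * scratch allocation, so libxsmm_dnn_convolution_setup sizes it as the maximum of the three
 * per-pass requirements rather than their sum. Likewise, the offset-based BRGEMM variants
 * above precompute A_offsets/B_offsets arrays so one kernel call can walk all R*S filter
 * taps and the blocked input channels in a single reduction batch. */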
#undef MIXED
#undef KHWC
#undef HWKC
#undef CHWK
#undef HWCK
LIBXSMM_API libxsmm_dnn_layer* libxsmm_dnn_create_conv_layer( libxsmm_dnn_conv_desc conv_desc, libxsmm_dnn_err_t* status ) {
  libxsmm_dnn_layer* handle = 0;
  *status = LIBXSMM_DNN_SUCCESS;
  /* currently we don't support NCHW */
  if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCHW) > 0 ) {
    *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW;
    return 0;
  }
  /* currently we don't support KCRS */
  if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_KCRS) > 0 ) {
    *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS;
    return 0;
  }
  /* we only support physical padding these days */
  /* @TODO: add logical padding support for datatypes other than FP32 */
  if ( ((conv_desc.pad_h != conv_desc.pad_h_in) ||
        (conv_desc.pad_w != conv_desc.pad_w_in) ||
        (conv_desc.pad_h != conv_desc.pad_h_out) ||
        (conv_desc.pad_w != conv_desc.pad_w_out)) &&
       (conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) && (conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) ) {
    *status = LIBXSMM_DNN_ERR_INVALID_PADDING;
    return 0;
  }
  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  handle = (libxsmm_dnn_layer*)calloc(1, sizeof(libxsmm_dnn_layer));
  if (0 != handle) {
    /* initialize known handle components */
    handle->desc = conv_desc;
    handle->datatype_in = conv_desc.datatype_in;
    handle->datatype_out = conv_desc.datatype_out;
    /* select the intermediate format, only applicable for integer types */
    if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) {
      /* error */
    } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_BF16) ) {
      /* error */
    } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) {
      /* error */
    } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_I32) ) {
      /* error */
    } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_I8) ) {
      /* error */
    } else if ( (conv_desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (conv_desc.datatype_out != LIBXSMM_DNN_DATATYPE_F32) ) {
      /* error */
    } else {
      /* fine, no error */
    }
    handle->buffer_format = conv_desc.buffer_format;
    handle->filter_format = conv_desc.filter_format;
    handle->fuse_ops = conv_desc.fuse_ops;
    handle->options = conv_desc.options;
    /* derive additional values */
    handle->ifhp = conv_desc.H + 2 * conv_desc.pad_h_in;
    handle->ifwp = conv_desc.W + 2 * conv_desc.pad_w_in;
    handle->ofh = (conv_desc.H + 2 * conv_desc.pad_h - conv_desc.R) / conv_desc.u + 1;
    handle->ofw = (conv_desc.W + 2 * conv_desc.pad_w - conv_desc.S) / conv_desc.v + 1;
    handle->ofhp = handle->ofh + 2 * conv_desc.pad_h_out;
    handle->ofwp = handle->ofw + 2 * conv_desc.pad_w_out;
    handle->ifmblock = 1;
    handle->ofmblock = 1;
    handle->blocksifm = conv_desc.C;
    handle->blocksofm = conv_desc.K;
    handle->fwd_ofw_rb = 1;
    handle->fwd_ofh_rb = 1;
    handle->bwd_ofw_rb = 1;
    handle->bwd_ofh_rb = 1;
    handle->upd_ofw_rb = 1;
    handle->upd_ofh_rb = 1;
    handle->fm_lp_block = 1;
    handle->blocksifm_blocking = 1;
    handle->blocksofm_blocking = 1;
    /* Set algorithm to use */
    if (conv_desc.algo == LIBXSMM_DNN_CONV_ALGO_AUTO) {
      handle->algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    } else {
      handle->algo = conv_desc.algo;
    }
    if ( handle->algo != LIBXSMM_DNN_CONV_ALGO_DIRECT ) {
      *status = LIBXSMM_DNN_ERR_INVALID_ALGO;
      free(handle);
      handle = 0;
      return 0;
    }
    *status = libxsmm_dnn_convolution_setup(handle);
  } else {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
  }
  /* account for a possibly deallocated handle */
  if (LIBXSMM_DNN_SUCCESS != *status) {
    handle = 0;
  }
  return handle;
}
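/* A minimal usage sketch (illustrative addition; the descriptor values below are assumed
 * for a hypothetical 3x3, stride-1, single-threaded FP32 layer and are not taken from this
 * file; memset assumes <string.h>): */
#if 0
libxsmm_dnn_err_t status;
libxsmm_dnn_conv_desc desc;
libxsmm_dnn_layer* layer;
memset(&desc, 0, sizeof(desc));
desc.N = 1; desc.C = 64; desc.K = 64; desc.H = 56; desc.W = 56;
desc.R = 3; desc.S = 3; desc.u = 1; desc.v = 1;
desc.pad_h = 1; desc.pad_w = 1;
desc.pad_h_in = 0; desc.pad_w_in = 0; desc.pad_h_out = 0; desc.pad_w_out = 0;
desc.threads = 1;
desc.algo = LIBXSMM_DNN_CONV_ALGO_AUTO;
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
layer = libxsmm_dnn_create_conv_layer(desc, &status);
if (LIBXSMM_DNN_SUCCESS == status) {
  /* ... bind tensors/scratch and execute, then release the handle ... */
  libxsmm_dnn_destroy_conv_layer(layer);
}
#endif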
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_conv_layer( const libxsmm_dnn_layer* handle ) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  if (0 != handle) {
    /* Deallocate barrier */
    if (handle->barrier != 0) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }
    /* deallocate handle structure itself */
    free(/*remove constness*/(libxsmm_dnn_layer*)handle);
  }
  return status;
}
LIBXSMM_API
libxsmm_dnn_tensor_datalayout
*
libxsmm_dnn_create_tensor_datalayout
(
const
libxsmm_dnn_layer
*
handle
,
const
libxsmm_dnn_tensor_type
type
,
libxsmm_dnn_err_t
*
status
)
{
libxsmm_dnn_tensor_datalayout
*
layout
;
*
status
=
LIBXSMM_DNN_SUCCESS
;
layout
=
0
;
if
(
handle
!=
0
)
{
/* zero entire content; not only safer but also sets data and code pointers to NULL */
layout
=
(
libxsmm_dnn_tensor_datalayout
*
)
calloc
(
1
,
sizeof
(
libxsmm_dnn_tensor_datalayout
));
if
(
layout
!=
0
)
{
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
||
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_ACTIVATION
;
if
((
handle
->
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
)
>
0
)
{
if
(
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
            /* @TODO this needs to change */
          } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_I32) ) {
            if ( ((type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT)) ) {
              layout->datatype = handle->datatype_in;
            } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
              layout->datatype = handle->datatype_out;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->ifwp;
                layout->dim_size[2] = handle->ifhp;
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofwp;
                layout->dim_size[2] = handle->ofhp;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            }
          } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->ifwp;
                layout->dim_size[2] = handle->ifhp;
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofwp;
                layout->dim_size[2] = handle->ofhp;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            }
          } else if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
            if ( ((type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT)) ) {
              layout->datatype = handle->datatype_in;
            } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) {
              layout->datatype = handle->datatype_out;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->num_dims = 5;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->ifwp;
                layout->dim_size[2] = handle->ifhp;
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
                layout->num_dims = 5;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofwp;
                layout->dim_size[2] = handle->ofhp;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->num_dims = 5;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofwp;
                layout->dim_size[2] = handle->ofhp;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
                layout->num_dims = 5;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->ifwp;
                layout->dim_size[2] = handle->ifhp;
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
                layout->dim_size[0] = handle->ifmblock * handle->blocksifm;
                layout->dim_size[1] = handle->ifwp;
                layout->dim_size[2] = handle->ifhp;
                layout->dim_size[3] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
                layout->dim_size[1] = handle->ofwp;
                layout->dim_size[2] = handle->ofhp;
                layout->dim_size[3] = handle->desc.N;
              } else {
                /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_FILTER) ) {
        layout->format = handle->filter_format;
        layout->tensor_type = LIBXSMM_DNN_FILTER;
        if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 6;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_size[0] = handle->ofmblock;
              layout->dim_size[1] = handle->ifmblock;
              layout->dim_size[2] = handle->desc.S;
              layout->dim_size[3] = handle->desc.R;
              layout->dim_size[4] = handle->blocksifm;
              layout->dim_size[5] = handle->blocksofm;
            }
          } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 7;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_size[0] = handle->fm_lp_block;
              layout->dim_size[1] = handle->ofmblock;
              layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
              layout->dim_size[3] = handle->desc.S;
              layout->dim_size[4] = handle->desc.R;
              layout->dim_size[5] = handle->blocksifm;
              layout->dim_size[6] = handle->blocksofm;
            }
          } else if ( ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
            if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_FILTER) ) {
              layout->datatype = handle->datatype_in;
            } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
              layout->datatype = handle->datatype_out;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              if ((type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_FILTER)) {
                layout->num_dims = 7;
                layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
                layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
                layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
                layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
                layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
                layout->dim_size[0] = handle->fm_lp_block;
                layout->dim_size[1] = handle->ofmblock;
                layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
                layout->dim_size[3] = handle->desc.S;
                layout->dim_size[4] = handle->desc.R;
                layout->dim_size[5] = handle->blocksifm;
                layout->dim_size[6] = handle->blocksofm;
              }
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
          if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
              layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
              layout->dim_size[2] = handle->desc.S;
              layout->dim_size[3] = handle->desc.R;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
        layout->format = handle->filter_format;
        layout->tensor_type = LIBXSMM_DNN_REGULAR_FILTER_TRANS;
        if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 6;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->ifmblock;
              layout->dim_size[1] = handle->ofmblock;
              layout->dim_size[2] = handle->desc.S;
              layout->dim_size[3] = handle->desc.R;
              layout->dim_size[4] = handle->blocksofm;
              layout->dim_size[5] = handle->blocksifm;
            }
          } else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 7;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
              layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
              layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->fm_lp_block;
              layout->dim_size[1] = handle->ifmblock;
              layout->dim_size[2] = handle->ofmblock/handle->fm_lp_block;
              layout->dim_size[3] = handle->desc.S;
              layout->dim_size[4] = handle->desc.R;
              layout->dim_size[5] = handle->blocksofm;
              layout->dim_size[6] = handle->blocksifm;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
#if 0
} else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
layout->dim_size[2] = handle->desc.S;
layout->dim_size[3] = handle->desc.K;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
#endif
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) {
        layout->format = handle->buffer_format;
        layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
        if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->datatype = handle->datatype_out;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 2;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->ofmblock;
              layout->dim_size[1] = handle->blocksofm;
            }
#if 0
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 3;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->fm_lp_block;
layout->dim_size[1] = handle->ofmblock;
layout->dim_size[2] = handle->blocksofm;
}
#endif
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          layout->datatype = handle->datatype_out;
          if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) {
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 1;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( (type == LIBXSMM_DNN_BATCH_STATS) ) {
        layout->format = handle->buffer_format;
        layout->tensor_type = LIBXSMM_DNN_BATCH_STATS;
        if ((handle->buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
              layout->num_dims = 2;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
              layout->dim_size[0] = handle->ofmblock;
              layout->dim_size[1] = handle->desc.N;
              layout->dim_size[2] = handle->blocksofm;
              layout->dim_size[3] = 2;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
        layout->format = handle->buffer_format;
        layout->tensor_type = LIBXSMM_DNN_MAX_STATS_FWD;
        layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
        layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
        layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
        if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
          layout->num_dims = 2;
          layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
          layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
          layout->dim_size[0] = handle->ifmblock;
          layout->dim_size[1] = handle->desc.N;
        }
      } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
        layout->format = handle->buffer_format;
        layout->tensor_type = LIBXSMM_DNN_MAX_STATS_BWD;
        layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
        layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
        layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
        if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
          layout->num_dims = 2;
          layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
          layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
          layout->dim_size[0] = handle->ifmblock;
          layout->dim_size[1] = handle->desc.N;
        }
      } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
        layout->format = handle->buffer_format;
        layout->tensor_type = LIBXSMM_DNN_MAX_STATS_UPD;
        layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
        layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
        layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
        if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
          layout->num_dims = 2;
          layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
          layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
          layout->dim_size[0] = handle->ifmblock;
          layout->dim_size[1] = handle->desc.N;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return layout;
}
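
/* Editor's sketch (not compiled, hence #if 0): one way a caller can consume the
 * datalayout produced above -- multiply out dim_size[0..num_dims-1] to size the
 * allocation, then wrap the buffer via libxsmm_dnn_link_tensor(). The handle setup
 * is assumed to exist elsewhere; libxsmm_aligned_malloc and libxsmm_dnn_typesize
 * are libxsmm API helpers used on the assumption they match this library version. */
#if 0
{
  libxsmm_dnn_err_t status;
  libxsmm_dnn_tensor_datalayout* layout = libxsmm_dnn_create_tensor_datalayout(handle, LIBXSMM_DNN_REGULAR_INPUT, &status);
  if (0 != layout && LIBXSMM_DNN_SUCCESS == status) {
    unsigned int i, nelems = 1;
    for (i = 0; i < layout->num_dims; ++i) nelems *= layout->dim_size[i];
    {
      void* buffer = libxsmm_aligned_malloc((size_t)nelems * libxsmm_dnn_typesize(layout->datatype), 2097152);
      libxsmm_dnn_tensor* input = libxsmm_dnn_link_tensor(layout, buffer, &status);
      /* ... bind "input" to the handle, execute, then release and free ... */
    }
    libxsmm_dnn_destroy_tensor_datalayout(layout);
  }
}
#endif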

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_bf16_filter(const libxsmm_dnn_layer* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (handle != 0) {
    if ( (handle->reg_filter != 0) && (handle->reg_filter_tr != 0) ) {
      /* TODO handle more datatypes */
      int ifm1, ifm2, kj, ki, ofm1, ofm2;
      int ofmblock_lp = handle->ofmblock/handle->fm_lp_block;
      int ifmblock_lp = handle->ifmblock/handle->fm_lp_block;
      int lpb = handle->fm_lp_block;
      LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, wt, (libxsmm_bfloat16*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb);
      LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, tr_wt, (libxsmm_bfloat16*)handle->reg_filter_tr->data, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb);

      /* TODO we might want to do this in parallel.... */
      for ( ifm1 = 0; ifm1 < handle->blocksifm; ++ifm1 ) {
        for ( ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1 ) {
          for ( kj = 0; kj < handle->desc.R; ++kj ) {
            for ( ki = 0; ki < handle->desc.S; ++ki ) {
              for ( ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2 ) {
                for ( ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2 ) {
                  LIBXSMM_VLA_ACCESS(7, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, ofm2/lpb, ifm2, ofm2%lpb, handle->blocksofm, handle->desc.R, handle->desc.S, ofmblock_lp, handle->ifmblock, lpb) =
                    LIBXSMM_VLA_ACCESS(7, wt, ofm1, ifm1, kj, ki, ifm2/lpb, ofm2, ifm2%lpb, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, lpb);
                }
              }
            }
          }
        }
      }
    } else {
      status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}
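
/* Editor's note: the f32 variant below performs the same filter preparation as the
 * bf16 one above, minus the VNNI pairing -- it swaps the blocked K/C dimensions and
 * mirrors the R and S offsets (R-1-kj, S-1-ki), i.e. it materializes the transposed,
 * spatially flipped filter that the backward pass consumes. */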

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_filter(const libxsmm_dnn_layer* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (handle != 0) {
    if ( (handle->reg_filter != 0) && (handle->reg_filter_tr != 0) ) {
      /* TODO handle more datatypes */
      int ifm1, ifm2, kj, ki, ofm1, ofm2;
      LIBXSMM_VLA_DECL(6, float, wt, (float*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
      LIBXSMM_VLA_DECL(6, float, tr_wt, (float*)handle->reg_filter_tr->data, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);

      /* TODO we might want to do this in parallel.... */
      for ( ifm1 = 0; ifm1 < handle->blocksifm; ++ifm1 ) {
        for ( ofm1 = 0; ofm1 < handle->blocksofm; ++ofm1 ) {
          for ( kj = 0; kj < handle->desc.R; ++kj ) {
            for ( ki = 0; ki < handle->desc.S; ++ki ) {
              for ( ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2 ) {
                for ( ifm2 = 0; ifm2 < handle->ifmblock; ++ifm2 ) {
                  LIBXSMM_VLA_ACCESS(6, tr_wt, ifm1, ofm1, handle->desc.R-1-kj, handle->desc.S-1-ki, ofm2, ifm2, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) =
                    LIBXSMM_VLA_ACCESS(6, wt, ofm1, ifm1, kj, ki, ifm2, ofm2, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                }
              }
            }
          }
        }
      }
    } else {
      status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}
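
/* Editor's sketch (not compiled): the intended call order around the binding API that
 * follows -- bind a layout-compatible tensor, run, then release it before tearing the
 * handle down. Error handling is elided; "input" is the hypothetical tensor from the
 * earlier layout sketch. */
#if 0
{
  libxsmm_dnn_err_t status = libxsmm_dnn_bind_tensor(handle, input, LIBXSMM_DNN_REGULAR_INPUT);
  /* ... libxsmm_dnn_execute_st(...) ... */
  if (LIBXSMM_DNN_SUCCESS == status) {
    status = libxsmm_dnn_release_tensor(handle, LIBXSMM_DNN_REGULAR_INPUT);
  }
}
#endif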

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)        && (type != LIBXSMM_DNN_GRADIENT_INPUT)  &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT)       && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_REGULAR_FILTER)       && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
       (type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) &&
       (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }

  if (handle != 0 && tensor != 0) {
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_create_tensor_datalayout(handle, type, &status);

    if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
      if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
        handle->reg_input = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
        handle->grad_input = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
        handle->reg_output = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
        handle->grad_output = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
        handle->reg_filter = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
        handle->grad_filter = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
        handle->reg_bias = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
        handle->grad_bias = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
        handle->reg_filter_tr = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
        handle->batch_stats = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
        handle->maxstats_fwd = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
        handle->maxstats_bwd = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
        handle->maxstats_upd = (libxsmm_dnn_tensor*)tensor;
      } else {
        /* cannot happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }

    libxsmm_dnn_destroy_tensor_datalayout(handle_layout);
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  return status;
}

LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_get_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor* return_tensor = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)        && (type != LIBXSMM_DNN_GRADIENT_INPUT)  &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT)       && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_REGULAR_FILTER)       && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
       (type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) &&
       (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return return_tensor;
  }

  if (handle != 0) {
    if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
      return_tensor = handle->reg_input;
    } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
      return_tensor = handle->grad_input;
    } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
      return_tensor = handle->reg_output;
    } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
      return_tensor = handle->grad_output;
    } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
      return_tensor = handle->reg_filter;
    } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
      return_tensor = handle->grad_filter;
    } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
      return_tensor = handle->reg_bias;
    } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
      return_tensor = handle->grad_bias;
    } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
      return_tensor = handle->reg_filter_tr;
    } else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
      return_tensor = handle->batch_stats;
    } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
      return_tensor = handle->maxstats_fwd;
    } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
      return_tensor = handle->maxstats_bwd;
    } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
      return_tensor = handle->maxstats_upd;
    } else {
      /* cannot happen */
    }
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  return return_tensor;
}

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)        && (type != LIBXSMM_DNN_GRADIENT_INPUT)  &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT)       && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
       (type != LIBXSMM_DNN_REGULAR_FILTER)       && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
       (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
       (type != LIBXSMM_DNN_REGULAR_FILTER_TRANS) && (type != LIBXSMM_DNN_BATCH_STATS) &&
       (type != LIBXSMM_DNN_MAX_STATS_FWD) && (type != LIBXSMM_DNN_MAX_STATS_BWD) && (type != LIBXSMM_DNN_MAX_STATS_UPD) ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }

  if (handle != 0) {
    if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
      handle->reg_input = 0;
    } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
      handle->grad_input = 0;
    } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
      handle->reg_output = 0;
    } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
      handle->grad_output = 0;
    } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
      handle->reg_filter = 0;
    } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
      handle->grad_filter = 0;
    } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
      handle->reg_bias = 0;
    } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
      handle->grad_bias = 0;
    } else if ( type == LIBXSMM_DNN_REGULAR_FILTER_TRANS ) {
      handle->reg_filter_tr = 0;
    } else if ( type == LIBXSMM_DNN_BATCH_STATS ) {
      handle->batch_stats = 0;
    } else if ( type == LIBXSMM_DNN_MAX_STATS_FWD ) {
      handle->maxstats_fwd = 0;
    } else if ( type == LIBXSMM_DNN_MAX_STATS_BWD ) {
      handle->maxstats_bwd = 0;
    } else if ( type == LIBXSMM_DNN_MAX_STATS_UPD ) {
      handle->maxstats_upd = 0;
    } else {
      /* cannot happen */
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  return status;
}

LIBXSMM_API size_t libxsmm_dnn_get_scratch_size(const libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status) {
  size_t l_scratch_size = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_BWD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_UPD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
      default: {
        *status = LIBXSMM_DNN_ERR_INVALID_KIND;
      }
    }
    l_scratch_size += handle->scratch_size + 64;
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return l_scratch_size;
}
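
/* Editor's sketch (not compiled): the scratch workflow implied by the query above and
 * the bind/release pair below. The +64 in the reported size leaves room for the
 * 64-byte alignment fix-up that libxsmm_dnn_bind_scratch applies to unaligned
 * pointers; libxsmm_aligned_malloc/libxsmm_free are libxsmm API helpers. */
#if 0
{
  libxsmm_dnn_err_t status;
  const size_t scratch_size = libxsmm_dnn_get_scratch_size(handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
  void* scratch = libxsmm_aligned_malloc(scratch_size, 2097152);
  status = libxsmm_dnn_bind_scratch(handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch);
  /* ... run the layer ... */
  status = libxsmm_dnn_release_scratch(handle, LIBXSMM_DNN_COMPUTE_KIND_ALL);
  libxsmm_free(scratch);
}
#endif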

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, const void* scratch) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  uintptr_t address = (uintptr_t)scratch;
  size_t offset = 0;

  if (scratch == 0) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
    return status;
  }

  if (0 != handle) {
    if (address % 64 == 0) {
      handle->scratch = (void*)address;
    } else {
      offset = (64 - address % 64);
      handle->scratch = (void*)(address + offset);
    }
    address += handle->scratch_size + 64;
    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_BWD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_UPD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_KIND;
      }
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    handle->scratch = 0;
    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_BWD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_UPD: break;
      case LIBXSMM_DNN_COMPUTE_KIND_ALL: break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_KIND;
      }
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}
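
/* Editor's note: internal_execute_st below is a pure dispatcher -- it selects the
 * concrete convolution kernel by algorithm (only LIBXSMM_DNN_CONV_ALGO_DIRECT is
 * handled), then compute kind (FWD/BWD/UPD/BWDUPD), then buffer format (custom
 * blocked vs. NHWC) and filter format (custom blocked vs. RSCK); any other
 * combination yields LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE, _INVALID_KIND,
 * or _INVALID_ALGO. */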

LIBXSMM_API_INLINE libxsmm_dnn_err_t internal_execute_st(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    switch (handle->algo) {
      case LIBXSMM_DNN_CONV_ALGO_DIRECT: {
        switch (kind) {
          case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
            switch (handle->buffer_format) {
              case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_fwd_custom_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
                    status = libxsmm_dnn_convolve_st_fwd_nhwc_rsck(handle, start_thread, tid);
                  } break;
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_fwd_nhwc_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              default: {
                status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
              }
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWD: {
            switch (handle->buffer_format) {
              case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
                    status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck(handle, start_thread, tid);
                  } break;
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_bwd_nhwc_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              default: {
                status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
              }
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_UPD: {
            switch (handle->buffer_format) {
              case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_upd_custom_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
                    status = libxsmm_dnn_convolve_st_upd_nhwc_rsck(handle, start_thread, tid);
                  } break;
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_upd_nhwc_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              default: {
                status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
              }
            }
          } break;
          case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
            switch (handle->buffer_format) {
              case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_upd_custom_custom(handle, start_thread, tid);
                    status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              case LIBXSMM_DNN_TENSOR_FORMAT_NHWC: {
                switch (handle->filter_format) {
                  case LIBXSMM_DNN_TENSOR_FORMAT_RSCK: {
                    status = libxsmm_dnn_convolve_st_upd_nhwc_rsck(handle, start_thread, tid);
                    status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck(handle, start_thread, tid);
                  } break;
                  case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
                    status = libxsmm_dnn_convolve_st_upd_nhwc_custom(handle, start_thread, tid);
                    status = libxsmm_dnn_convolve_st_bwd_nhwc_custom(handle, start_thread, tid);
                  } break;
                  default: {
                    status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
                  }
                }
              } break;
              default: {
                status = LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE;
              }
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_KIND;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_ALGO;
      }
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}

LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_execute_st(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind, /*unsigned*/int start_thread, /*unsigned*/int tid) {
  return internal_execute_st(handle, kind, start_thread, tid);
}

LIBXSMM_API void libxsmm_dnn_execute(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind) {
#if defined(_OPENMP)
# pragma omp parallel num_threads(handle->desc.threads)
  {
    const int tid = omp_get_thread_num();
    internal_execute_st(handle, kind, 0, tid);
  }
#else
  internal_execute_st(handle, kind, 0/*start_thread*/, 0/*tid*/);
#endif
}
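
/* Editor's sketch (not compiled): driving the thread-safe entry point from a caller's
 * own OpenMP region, mirroring what libxsmm_dnn_execute does internally; each thread
 * passes its own tid while start_thread stays 0. */
#if 0
#pragma omp parallel num_threads(handle->desc.threads)
{
  const int tid = omp_get_thread_num();
  libxsmm_dnn_execute_st(handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid);
}
#endif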
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c
0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_backward.h"
#include "libxsmm_main.h"

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
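
/* Editor's note: in the VNNI layout used below, bf16 values are stored as pairs of
 * adjacent rows packed into 32-bit lanes. The 16x16 kernel therefore moves a 16x16
 * bf16 tile by treating it as an 8x16 tile of 32-bit pair-elements: a byte shuffle
 * regroups the pairs, then 64-bit unpacks and 128-bit lane shuffles complete the
 * transpose, so only 8 loads and 8 stores are needed per tile. */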

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_transpose_16x16_kernel(void* source_void, void* dest_void, int source_stride, int dest_stride) {
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  libxsmm_bfloat16* source = (libxsmm_bfloat16*)source_void;
  libxsmm_bfloat16* dest = (libxsmm_bfloat16*)dest_void;
  __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
  __m512i tmp0, tmp1, tmp2, tmp3;
  const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100);

  zmm0 = _mm512_load_epi32(source);
  zmm1 = _mm512_load_epi32(source + source_stride);
  zmm2 = _mm512_load_epi32(source + source_stride*2);
  zmm3 = _mm512_load_epi32(source + source_stride*3);
  zmm4 = _mm512_load_epi32(source + source_stride*4);
  zmm5 = _mm512_load_epi32(source + source_stride*5);
  zmm6 = _mm512_load_epi32(source + source_stride*6);
  zmm7 = _mm512_load_epi32(source + source_stride*7);

  zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh);
  zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh);
  zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh);
  zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh);
  zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh);
  zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh);
  zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh);
  zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh);

  tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1);
  tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1);
  tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3);
  tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3);
  zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5);
  zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5);
  zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7);
  zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7);

  zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88);
  zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd);
  zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88);
  zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd);
  tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88);
  tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd);
  tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88);
  tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd);

  zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88);
  zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88);
  zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88);
  zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88);
  zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd);
  zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd);
  zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd);
  zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd);

  _mm512_store_epi32(dest, zmm0);
  _mm512_store_epi32(dest + dest_stride, zmm1);
  _mm512_store_epi32(dest + dest_stride*2, zmm2);
  _mm512_store_epi32(dest + dest_stride*3, zmm3);
  _mm512_store_epi32(dest + dest_stride*4, zmm4);
  _mm512_store_epi32(dest + dest_stride*5, zmm5);
  _mm512_store_epi32(dest + dest_stride*6, zmm6);
  _mm512_store_epi32(dest + dest_stride*7, zmm7);
#else
  LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride);
#endif
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_transpose_kernel(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out) {
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  const int _M = M/16, _N = N/16;
  int i = 0, j = 0;
  for (i = 0; i < _N; i++) {
    for (j = 0; j < _M; j++) {
      bf16_vnni_transpose_16x16_kernel((libxsmm_bfloat16*)src + i*16*ld_in + j*32, (libxsmm_bfloat16*)dst + j*16*ld_out + i*32, ld_in*2, ld_out*2);
    }
  }
#else
  LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
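
/* Editor's note, with a worked example: for M = N = 32 and ld_in = ld_out = 32 the
 * loop above issues four 16x16 tile transposes; tile (i,j) is read at an offset of
 * i*16*ld_in + j*32 bf16 elements and written to the mirrored offset
 * j*16*ld_out + i*32, while the *2 on the strides accounts for each VNNI pair-row
 * spanning two logical rows when the 16x16 kernel steps through memory. */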

LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
    const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock;
    const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
    int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
    if ( 0 == env_brgemm_pf_oob ) {
    } else {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (brgemm_pf_oob > 0) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }
    {
      /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
      gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c"
    }
  } else {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
    {
      /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock, NULL, NULL, &ldC, NULL, NULL, NULL, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c"
    }
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
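
/* Editor's note: both dispatch paths above encode the same (col-major) problem shape,
 * M = ifmblock, N = bwd_ofh_rb*bwd_ofw_rb (or ofw for the fallback), K = ofmblock,
 * with beta picked from avoid_acc_load_bwd: beta = 0 overwrites the input gradient,
 * beta = 1 accumulates into it. Setting the BRGEMM_PF_OOB environment variable to a
 * nonzero integer switches the batch-reduce kernels to out-of-bounds prefetching. */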

LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bfloat16 element_filter_type;
      typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
      typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
      const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
      const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
      const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock;
      const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
      int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
      /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
  } else {
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
    /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}

LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bfloat16 element_filter_type;
      typedef libxsmm_bsmmfunction gemm_function;
      typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs;
      typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd;
      gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs;
      gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd;
      gemm_function tile_config_kernel = handle->bwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
  } else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
    /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
      typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
      const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
      const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
      const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v) : (libxsmm_blasint)handle->ifmblock;
      const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
      int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
      /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*handle->bwd_ofw_rb, handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb*(handle->bwd_ofw_rb-1), handle->ofmblock, &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  } else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v*handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* let's do an ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid) {
  return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* some portable macros for BF16 <-> FP32 */
#   include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bsmmfunction gemm_function;
      typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs;
      typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd;
      gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs;
      gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd;
      gemm_function tile_config_kernel = handle->bwd_config_kernel;
#     include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c"
#     include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
  else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v * handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(handle->ifmblock, handle->ofw, handle->ofmblock,
      stride_a, stride_b, NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* some portable macros for BF16 <-> FP32 */
#   include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
#   include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
#   include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
}
#endif
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldC = (handle->spread_input_bwd == 1)
      ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
    const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
    int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
    if (0 != env_brgemm_pf_oob) {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (brgemm_pf_oob > 0) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }
    {
      /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
      gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
#     include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
    }
  }
  else {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
      ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
    /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
    gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock,
      &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
#   include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
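/* The BRGEMM_PF_OOB environment variable read above is an opt-in knob: any positive
 * integer switches the dispatched BRGEMM kernels to out-of-bounds prefetching
 * (LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB). A minimal sketch of how a caller might enable
 * it before the layer runs; the helper below is hypothetical (only the variable name
 * comes from the code above) and setenv() assumes a POSIX environment. */
#include <stdlib.h>

static void enable_brgemm_oob_prefetch(void) {
  /* must be set before the kernels are dispatched, i.e. before the first execution */
  setenv("BRGEMM_PF_OOB", "1", 1 /* overwrite any existing value */);
}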
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldC = (handle->spread_input_bwd == 1)
      ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
    const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
    int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
    if (0 != env_brgemm_pf_oob) {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (brgemm_pf_oob > 0) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }
    {
      /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
      gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
#     include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
    }
  }
  else {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
      ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
    /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
    gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock,
      &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
#   include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
  if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      if (handle->use_fallback_bwd_loops == 0) {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
        const libxsmm_blasint ldx = ((libxsmm_blasint)handle->ofmblock);
        const libxsmm_blasint ldA = handle->ifmblock;
        const libxsmm_blasint ldC = (handle->spread_input_bwd == 1) ? handle->ifmblock * handle->desc.v : handle->ifmblock;
        const float beta = (handle->avoid_acc_load_bwd) ? 0.f : 1.f;
        int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
        int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
        int brgemm_pf_oob = 0;
        const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
        if (0 != env_brgemm_pf_oob) {
          brgemm_pf_oob = atoi(env_brgemm_pf_oob);
        }
        if (brgemm_pf_oob > 0) {
          prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
        }
        {
          /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
          gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
            &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
          gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
            &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
#         include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c"
        }
      }
      else {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction gemm_function;
        const libxsmm_blasint ldx = ((libxsmm_blasint)handle->desc.v * handle->ifmblock);
        /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
        gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock,
          NULL, NULL, &ldx, NULL, NULL, NULL, NULL);
#       include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c"
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
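/* All three bwd entry points in this file share the (handle, start_thread, tid)
 * convention: every participating thread calls in with its own id and the templates
 * partition the work internally. A minimal, hypothetical OpenMP driver sketch; the
 * helper name and the error aggregation are assumptions, only the function and type
 * names appear in this file: */
#include <omp.h>

static libxsmm_dnn_err_t run_bwd_custom_custom(libxsmm_dnn_layer* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
# pragma omp parallel
  {
    const int tid = omp_get_thread_num();
    /* start_thread == 0: all threads of this team participate in the convolution */
    const libxsmm_dnn_err_t my_status = libxsmm_dnn_convolve_st_bwd_custom_custom(handle, 0, tid);
    if (my_status != LIBXSMM_DNN_SUCCESS) {
#     pragma omp critical
      status = my_status;
    }
  }
  return status;
}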
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
  if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      if (handle->use_fallback_bwd_loops == 0) {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldC = (handle->spread_input_bwd == 1)
          ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
        const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
        int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
        int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
        int brgemm_pf_oob = 0;
        const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
        if (0 != env_brgemm_pf_oob) {
          brgemm_pf_oob = atoi(env_brgemm_pf_oob);
        }
        if (brgemm_pf_oob > 0) {
          prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
        }
        {
          /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
          gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
          gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
#         include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
        }
      }
      else {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction gemm_function;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
          ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
        /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
        gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock,
          &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
#       include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
  if (handle->grad_input == 0 || handle->grad_output == 0 || handle->reg_filter == 0 || handle->scratch == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      if (handle->use_fallback_bwd_loops == 0) {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldC = (handle->spread_input_bwd == 1)
          ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
        const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
        int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
        int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
        int brgemm_pf_oob = 0;
        const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
        if (0 != env_brgemm_pf_oob) {
          brgemm_pf_oob = atoi(env_brgemm_pf_oob);
        }
        if (brgemm_pf_oob > 0) {
          prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
        }
        {
          /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
          gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
          gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
#         include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
        }
      }
      else {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction gemm_function;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
          ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
        /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
        gemm_function gemm_kernel = libxsmm_smmdispatch(handle->ifmblock, handle->ofw, handle->ofmblock,
          &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
#       include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h
0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_BACKWARD_H
#define LIBXSMM_DNN_CONVOLUTION_BACKWARD_H
#include <libxsmm_dnn_convolution.h>
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif /* LIBXSMM_DNN_CONVOLUTION_BACKWARD_H */
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c
0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_forward.h"
#include "libxsmm_main.h"
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
#if 1
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
  const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v * handle->ifmblock;
  const libxsmm_blasint ldA = handle->ofmblock;
  const libxsmm_blasint ldC = handle->ofmblock;
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  int l_flags = (LIBXSMM_GEMM_FLAGS('N', 'N')) | handle->fwd_flags;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int brgemm_pf_oob = 0;
  const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
  if (0 != env_brgemm_pf_oob) {
    brgemm_pf_oob = atoi(env_brgemm_pf_oob);
  }
  if (brgemm_pf_oob > 0) {
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
  }
  {
    /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
#else
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
  typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs;
  typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd;
  {
    gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32;
    gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32;
    gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32;
    gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32;
#endif
#   include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c"
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
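/* Note on the "#if 1" toggle above: as written, the first branch is compiled, so the
 * forward BRGEMM kernels are obtained at call time via
 * libxsmm_smmdispatch_reducebatch_addr; flipping the toggle to 0 would instead pull
 * kernels that were prepared at handle-creation time from the handle
 * (fwd_compute_kernel_addr_a_f32 and friends). This reading is inferred from the
 * code itself; the history and intent of the toggle are not visible here. */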
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  {
    typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
    typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
    const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v * handle->ifmblock;
    const libxsmm_blasint ldA = handle->ofmblock;
    const libxsmm_blasint ldC = handle->ofmblock;
    const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
    int l_flags = (LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N')) | handle->fwd_flags;
    /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
#   include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c"
#   include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction gemm_function;
  typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a;
  typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b;
  typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd;
  gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a;
  gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b;
  gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd;
  gemm_function tile_config_kernel = handle->fwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
  typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
  const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v * handle->ifmblock;
  const libxsmm_blasint ldA = handle->ofmblock;
  const libxsmm_blasint ldC = handle->ofmblock;
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  int l_flags = (LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N')) | handle->fwd_flags;
  gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
    &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
  gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
    &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
  gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
    &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
  gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
    &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction gemm_function;
  typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a;
  typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b;
  typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd;
  gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a;
  gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b;
  gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd;
  gemm_function tile_config_kernel = handle->fwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
}
#endif
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef unsigned char element_input_type;
  typedef int element_output_type;
  typedef char element_filter_type;
  /* Basically we need only offset based and strided BRGEMMs */
  libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.subimrs;
  libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided2 = handle->gemm_fwd2.xgemm.subimrs;
  libxsmm_subimmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.subimro;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c"
#else
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef unsigned char element_input_type;
  typedef unsigned char element_output_type;
  typedef char element_filter_type;
  /* Basically we need only offset based and strided BRGEMMs */
  libxsmm_sububmmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.sububmrs;
  libxsmm_sububmmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.sububmro;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c"
#else
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  const libxsmm_blasint ldx = (handle->pack_input == 1)
    ? (libxsmm_blasint)handle->blocksifm * handle->ifmblock
    : (libxsmm_blasint)handle->blocksifm * handle->desc.v * handle->ifmblock;
  const libxsmm_blasint ldA = handle->ofmblock;
  const libxsmm_blasint ldC = handle->blocksofm * handle->ofmblock;
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
  int l_flags = (LIBXSMM_GEMM_FLAGS('N', 'N')) | handle->fwd_flags;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int brgemm_pf_oob = 0;
  const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
  if (0 != env_brgemm_pf_oob) {
    brgemm_pf_oob = atoi(env_brgemm_pf_oob);
  }
  if (brgemm_pf_oob > 0) {
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
  }
  {
    /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
#   include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  const libxsmm_blasint ldx = (handle->pack_input == 1)
    ? (libxsmm_blasint)handle->blocksifm * handle->ifmblock
    : (libxsmm_blasint)handle->blocksifm * handle->desc.v * handle->ifmblock;
  const libxsmm_blasint ldA = handle->blocksofm * handle->ofmblock;
  const libxsmm_blasint ldC = handle->blocksofm * handle->ofmblock;
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
  int l_flags = (LIBXSMM_GEMM_FLAGS('N', 'N')) | handle->fwd_flags;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int brgemm_pf_oob = 0;
  const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
  if (0 != env_brgemm_pf_oob) {
    brgemm_pf_oob = atoi(env_brgemm_pf_oob);
  }
  if (brgemm_pf_oob > 0) {
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
  }
  {
    /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
      &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
#   include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
  }
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
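/* Worked example for the NHWC leading dimensions above (illustrative numbers only):
 * with blocksifm = 4, ifmblock = 16, a stride of desc.v = 2 and pack_input == 0, the
 * input leading dimension is ldx = 4*2*16 = 128 elements, i.e. all input channels of
 * a pixel times the stride; with pack_input == 1 the stride factor drops out and
 * ldx = 4*16 = 64. In the RSCK case ldA = ldC = blocksofm*ofmblock spans all output
 * channels of a pixel. */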
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I32) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I8) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 &&
      handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
#if 1
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
      const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v * handle->ifmblock;
      const libxsmm_blasint ldA = handle->ofmblock;
      const libxsmm_blasint ldC = handle->ofmblock;
      const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
      int l_flags = (LIBXSMM_GEMM_FLAGS('N', 'N')) | handle->fwd_flags;
      int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
      int brgemm_pf_oob = 0;
      const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
      if (0 != env_brgemm_pf_oob) {
        brgemm_pf_oob = atoi(env_brgemm_pf_oob);
      }
      if (brgemm_pf_oob > 0) {
        prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
      }
      {
        /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
        gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
          &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
        gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
          &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
#else
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
      typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs;
      typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd;
      {
        gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32;
        gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32;
        gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32;
        gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32;
#endif
#       include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c"
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      const libxsmm_blasint ldx = (handle->pack_input == 1)
        ? (libxsmm_blasint)handle->blocksifm * handle->ifmblock
        : (libxsmm_blasint)handle->blocksifm * handle->desc.v * handle->ifmblock;
      const libxsmm_blasint ldA = handle->ofmblock;
      const libxsmm_blasint ldC = handle->blocksofm * handle->ofmblock;
      const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
      int l_flags = (LIBXSMM_GEMM_FLAGS('N', 'N')) | handle->fwd_flags;
      int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
      int brgemm_pf_oob = 0;
      const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
      if (0 != env_brgemm_pf_oob) {
        brgemm_pf_oob = atoi(env_brgemm_pf_oob);
      }
      if (brgemm_pf_oob > 0) {
        prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
      }
      {
        /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
        gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * handle->fwd_ofw_rb, handle->ifmblock,
          &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
        gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb * (handle->fwd_ofw_rb - 1), handle->ifmblock,
          &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
#       include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  return status;
}
LIBXSMM_API_INTERN
libxsmm_dnn_err_t
libxsmm_dnn_convolve_st_fwd_nhwc_rsck
(
libxsmm_dnn_layer
*
handle
,
int
start_thread
,
int
tid
)
{
libxsmm_dnn_err_t
status
=
LIBXSMM_DNN_SUCCESS
;
/* check if we have input, output and filter */
if
(
handle
->
reg_input
==
0
||
handle
->
reg_output
==
0
||
handle
->
reg_filter
==
0
)
{
status
=
LIBXSMM_DNN_ERR_DATA_NOT_BOUND
;
return
status
;
}
/* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512)
/*__AVX512F__*/
if
(
(
handle
->
target_archid
>=
LIBXSMM_X86_AVX512
)
&&
(
handle
->
target_archid
<=
LIBXSMM_X86_ALLFEAT
)
)
{
if
(
handle
->
desc
.
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
&&
handle
->
desc
.
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
{
status
=
libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32
(
handle
,
start_thread
,
tid
);
}
else
{
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
return
status
;
}
}
else
#endif
{
if
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
&&
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
{
const
libxsmm_blasint
ldx
=
(
handle
->
pack_input
==
1
)
?
(
libxsmm_blasint
)
handle
->
blocksifm
*
handle
->
ifmblock
:
(
libxsmm_blasint
)
handle
->
blocksifm
*
handle
->
desc
.
v
*
handle
->
ifmblock
;
const
libxsmm_blasint
ldA
=
handle
->
blocksofm
*
handle
->
ofmblock
;
const
libxsmm_blasint
ldC
=
handle
->
blocksofm
*
handle
->
ofmblock
;
const
float
beta
=
(
handle
->
avoid_acc_load
)
?
0
.
f
:
1
.
f
;
typedef
float
element_input_type
;
typedef
float
element_output_type
;
typedef
float
element_filter_type
;
typedef
libxsmm_smmfunction_reducebatch_addr
gemm_br_function
;
int
l_flags
=
(
LIBXSMM_GEMM_FLAGS
(
'N'
,
'N'
)
)
|
handle
->
fwd_flags
;
int
prefetch_mode
=
libxsmm_get_gemm_prefetch
(
LIBXSMM_GEMM_PREFETCH_NONE
);
int
brgemm_pf_oob
=
0
;
const
char
*
const
env_brgemm_pf_oob
=
getenv
(
"BRGEMM_PF_OOB"
);
if
(
0
==
env_brgemm_pf_oob
)
{
}
else
{
brgemm_pf_oob
=
atoi
(
env_brgemm_pf_oob
);
}
if
(
brgemm_pf_oob
>
0
)
{
prefetch_mode
=
libxsmm_get_gemm_prefetch
(
LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB
);
}
{
/* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
gemm_br_function
br_gemm_kernel
=
libxsmm_smmdispatch_reducebatch_addr
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
handle
->
fwd_ofw_rb
,
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
&
prefetch_mode
);
gemm_br_function
br_gemm_kernel2
=
libxsmm_smmdispatch_reducebatch_addr
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
(
handle
->
fwd_ofw_rb
-
1
),
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
&
prefetch_mode
);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
}
}
else
{
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
return
status
;
}
}
return
status
;
}
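/* Aside (editorial sketch, not part of libxsmm): the dispatches above fix the
 * microkernel shape from the handle's blocking, M = ofmblock,
 * N = fwd_ofh_rb*fwd_ofw_rb, K = ifmblock, while beta selects accumulation
 * (1.f) versus overwrite (0.f, when avoid_acc_load is set); the second kernel
 * covers the narrower (fwd_ofw_rb - 1) tile. Below is a minimal calling
 * sketch for the batch-reduce "address" kernel; the helper name is made up,
 * and the call shape (arrays of A/B block pointers plus a block count passed
 * by address) is an assumption about the br-GEMM convention of this code
 * base, not a confirmed signature. */
LIBXSMM_API_INLINE void br_gemm_call_sketch(libxsmm_smmfunction_reducebatch_addr kernel,
  const float** a_blocks, const float** b_blocks, float* c_block, unsigned long long n_blocks)
{
  /* one invocation reduces n_blocks partial products a_blocks[i] x b_blocks[i] into c_block */
  kernel(a_blocks, b_blocks, c_block, &n_blocks);
}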
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h  0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_FORWARD_H
#define LIBXSMM_DNN_CONVOLUTION_FORWARD_H

#include <libxsmm_dnn_convolution.h>

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);

#endif /* LIBXSMM_DNN_CONVOLUTION_FORWARD_H */
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c  0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke, Ankush Mandal, Jason Sewall (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_weight_update.h"
#include "libxsmm_main.h"

/* function prototypes for the implementations below */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_32x16(const libxsmm_bfloat16* in, libxsmm_bfloat16* out, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  __m512i r[16], t[16];
  const int in_width = ld_in, out_width = ld_out;
  const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
  int i, b;

  /* load 16 input rows of 32 bf16 elements each */
  for (i = 0; i < 16; ++i) {
    r[i] = _mm512_loadu_si512(in + i*in_width);
  }
  /* pass 1: interleave 16-bit elements of adjacent row pairs */
  for (i = 0; i < 8; ++i) {
    t[2*i+0] = _mm512_unpacklo_epi16(r[2*i], r[2*i+1]);
    t[2*i+1] = _mm512_unpackhi_epi16(r[2*i], r[2*i+1]);
  }
  /* pass 2: interleave 32-bit elements within groups of four registers */
  for (b = 0; b < 16; b += 4) {
    r[b+0] = _mm512_unpacklo_epi32(t[b+0], t[b+2]);
    r[b+1] = _mm512_unpackhi_epi32(t[b+0], t[b+2]);
    r[b+2] = _mm512_unpacklo_epi32(t[b+1], t[b+3]);
    r[b+3] = _mm512_unpackhi_epi32(t[b+1], t[b+3]);
  }
  /* pass 3: interleave 64-bit elements within groups of eight registers */
  for (b = 0; b < 16; b += 8) {
    for (i = 0; i < 4; ++i) {
      t[b+2*i+0] = _mm512_unpacklo_epi64(r[b+i], r[b+i+4]);
      t[b+2*i+1] = _mm512_unpackhi_epi64(r[b+i], r[b+i+4]);
    }
  }
  /* pass 4: gather the 128-bit lanes */
  for (b = 0; b < 16; b += 8) {
    for (i = 0; i < 4; ++i) {
      r[b+i+0] = _mm512_shuffle_i32x4(t[b+2*i], t[b+2*i+1], 0x88);
      r[b+i+4] = _mm512_shuffle_i32x4(t[b+2*i], t[b+2*i+1], 0xdd);
    }
  }
  /* pass 5: combine the 256-bit halves across the two register groups */
  for (i = 0; i < 8; ++i) {
    t[i+0] = _mm512_permutex2var_epi64(r[i+0], idx_lo, r[i+8]);
    t[i+8] = _mm512_permutex2var_epi64(r[i+8], idx_hi, r[i+0]);
  }
  /* store 32 output rows of 16 bf16 elements each */
  for (i = 0; i < 16; ++i) {
    LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + (2*i+0)*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t[i], 0));
    LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + (2*i+1)*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t[i], 1));
  }
#else
  LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
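/* Aside (editorial sketch, not part of libxsmm): transpose_32x16 above reads
 * a tile of 16 rows by 32 columns of bf16 values and writes its 32-row by
 * 16-column transpose. A scalar sketch of that mapping, with a made-up _ref
 * name, handy as a unit-test oracle for the vector kernel: */
LIBXSMM_API_INLINE void transpose_32x16_ref(const libxsmm_bfloat16* in, libxsmm_bfloat16* out, int ld_in, int ld_out)
{
  int r, c;
  for (r = 0; r < 16; ++r) {    /* input rows */
    for (c = 0; c < 32; ++c) {  /* input columns */
      out[c*ld_out + r] = in[r*ld_in + c];  /* element (r,c) lands at (c,r) */
    }
  }
}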
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_32xcols(const libxsmm_bfloat16* in, libxsmm_bfloat16* out, int col, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  __m512i r[16], t[16];
  const int in_width = ld_in, out_width = ld_out;
  const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
  const __mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1);
  int i, b;

  /* load the first col rows (0 < col < 16); the remaining registers stay undefined */
  for (i = 0; i < col; ++i) {
    r[i] = _mm512_loadu_si512(in + i*in_width);
  }
  for (i = col; i < 16; ++i) {
    r[i] = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
  }
  /* pass 1: interleave 16-bit elements of adjacent row pairs */
  for (i = 0; i < 8; ++i) {
    t[2*i+0] = _mm512_unpacklo_epi16(r[2*i], r[2*i+1]);
    t[2*i+1] = _mm512_unpackhi_epi16(r[2*i], r[2*i+1]);
  }
  /* pass 2: interleave 32-bit elements within groups of four registers */
  for (b = 0; b < 16; b += 4) {
    r[b+0] = _mm512_unpacklo_epi32(t[b+0], t[b+2]);
    r[b+1] = _mm512_unpackhi_epi32(t[b+0], t[b+2]);
    r[b+2] = _mm512_unpacklo_epi32(t[b+1], t[b+3]);
    r[b+3] = _mm512_unpackhi_epi32(t[b+1], t[b+3]);
  }
  /* pass 3: interleave 64-bit elements within groups of eight registers */
  for (b = 0; b < 16; b += 8) {
    for (i = 0; i < 4; ++i) {
      t[b+2*i+0] = _mm512_unpacklo_epi64(r[b+i], r[b+i+4]);
      t[b+2*i+1] = _mm512_unpackhi_epi64(r[b+i], r[b+i+4]);
    }
  }
  /* pass 4: gather the 128-bit lanes */
  for (b = 0; b < 16; b += 8) {
    for (i = 0; i < 4; ++i) {
      r[b+i+0] = _mm512_shuffle_i32x4(t[b+2*i], t[b+2*i+1], 0x88);
      r[b+i+4] = _mm512_shuffle_i32x4(t[b+2*i], t[b+2*i+1], 0xdd);
    }
  }
  /* pass 5: combine the 256-bit halves across the two register groups */
  for (i = 0; i < 8; ++i) {
    t[i+0] = _mm512_permutex2var_epi64(r[i+0], idx_lo, r[i+8]);
    t[i+8] = _mm512_permutex2var_epi64(r[i+8], idx_hi, r[i+0]);
  }
  /* store 32 output rows, masked to the first col elements of each row */
  for (i = 0; i < 16; ++i) {
    _mm256_mask_storeu_epi16(out + (2*i+0)*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t[i], 0));
    _mm256_mask_storeu_epi16(out + (2*i+1)*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t[i], 1));
  }
#else
  LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(col); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_input_pixels_bf16(const libxsmm_bfloat16* in, libxsmm_bfloat16* out, int M, int N, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  int i, j;
  int full16_chunks = N / 16;
  int remainder_cols = N % 16;
  int _N = N - remainder_cols;

  if (full16_chunks) {
    for (i = 0; i < M; i += 32) {
      for (j = 0; j < _N; j += 16) {
        transpose_32x16(in + i + ld_in*j, out + j + i*ld_out, ld_in, ld_out);
      }
    }
  }
  if (remainder_cols) {
    for (i = 0; i < M; i += 32) {
      transpose_32xcols(in + i + ld_in*full16_chunks*16, out + full16_chunks*16 + i*ld_out, remainder_cols, ld_in, ld_out);
    }
  }
#else
  LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
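/* Aside (editorial sketch, not part of libxsmm): taken together, the kernels
 * above give transpose_input_pixels_bf16 the contract
 * out[m*ld_out + n] = in[n*ld_in + m] for 0 <= m < M, 0 <= n < N, where M is
 * assumed to be a multiple of 32 (the N remainder is handled by
 * transpose_32xcols). A scalar reference of the full panel transpose, with a
 * made-up _ref name, under that assumed contract: */
LIBXSMM_API_INLINE void transpose_input_pixels_bf16_ref(const libxsmm_bfloat16* in, libxsmm_bfloat16* out,
  int M, int N, int ld_in, int ld_out)
{
  int m, n;
  for (n = 0; n < N; ++n) {
    for (m = 0; m < M; ++m) {
      out[m*ld_out + n] = in[n*ld_in + m];  /* plain M x N transpose */
    }
  }
}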
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function;
  gemm_function tile_config_kernel = handle->upd_config_kernel;
  gemm_function gemm_kernel = NULL;
  gemm_br_function br_gemm_kernel = NULL;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function;
  gemm_function tile_config_kernel = handle->upd_config_kernel;
  gemm_function gemm_kernel = NULL;
  gemm_br_function br_gemm_kernel = NULL;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
}
#endif
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
          && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
          && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
          && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
          && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16
          && handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
      status = libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction gemm_function;
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c"
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction gemm_function;
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->grad_output == 0 || handle->grad_filter == 0 || handle->scratch == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT)) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      status = libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(handle, start_thread, tid);
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction gemm_function;
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
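/* Aside (editorial sketch, not part of libxsmm): the three dispatchers above
 * share one decision ladder; condensed for the bf16 custom-custom path it
 * reads as below (the helper name is made up, every callee is defined earlier
 * in this file). AVX512_CORE up to CPX takes the FP32-emulated bf16 kernel,
 * CPX up to SPR the native AVX512-BF16 kernel, and SPR or newer the AMX
 * kernel. */
LIBXSMM_API_INLINE libxsmm_dnn_err_t upd_bf16_dispatch_sketch(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
    return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(handle, start_thread, tid);
  }
  if (handle->target_archid >= LIBXSMM_X86_AVX512_CPX) {
    return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(handle, start_thread, tid);
  }
  return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
}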
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h  0 → 100644
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H
#define LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H

#include <libxsmm_dnn_convolution.h>

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);

#endif /* LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H */