Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c454d419
Commit
c454d419
authored
May 12, 2023
by
lisj
Browse files
删除子模块的gitignore
parent
3359c1f1
Changes
264
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
7404 additions
and
0 deletions
+7404
-0
third_party/libxsmm/samples/smm/.make
third_party/libxsmm/samples/smm/.make
+0
-0
third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd
third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd
+0
-0
third_party/libxsmm/scripts/libxsmm_config.py
third_party/libxsmm/scripts/libxsmm_config.py
+145
-0
third_party/libxsmm/scripts/libxsmm_dispatch.py
third_party/libxsmm/scripts/libxsmm_dispatch.py
+116
-0
third_party/libxsmm/scripts/libxsmm_interface.py
third_party/libxsmm/scripts/libxsmm_interface.py
+195
-0
third_party/libxsmm/scripts/libxsmm_source.sh
third_party/libxsmm/scripts/libxsmm_source.sh
+68
-0
third_party/libxsmm/scripts/libxsmm_specialized.py
third_party/libxsmm/scripts/libxsmm_specialized.py
+205
-0
third_party/libxsmm/scripts/libxsmm_utilities.py
third_party/libxsmm/scripts/libxsmm_utilities.py
+320
-0
third_party/libxsmm/scripts/libxsmm_version.sh
third_party/libxsmm/scripts/libxsmm_version.sh
+30
-0
third_party/libxsmm/src/libxsmm_cpuid_arm.c
third_party/libxsmm/src/libxsmm_cpuid_arm.c
+96
-0
third_party/libxsmm/src/libxsmm_cpuid_x86.c
third_party/libxsmm/src/libxsmm_cpuid_x86.c
+336
-0
third_party/libxsmm/src/libxsmm_diff.h
third_party/libxsmm/src/libxsmm_diff.h
+144
-0
third_party/libxsmm/src/libxsmm_dnn.c
third_party/libxsmm/src/libxsmm_dnn.c
+759
-0
third_party/libxsmm/src/libxsmm_dnn_convolution.c
third_party/libxsmm/src/libxsmm_dnn_convolution.c
+2747
-0
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c
+719
-0
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h
+22
-0
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c
+544
-0
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h
+22
-0
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c
...party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c
+914
-0
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h
...party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h
+22
-0
No files found.
Too many changes to show.
To preserve performance only
264 of 264+
files are displayed.
Plain diff
Email patch
third_party/libxsmm/samples/smm/.make
0 → 100644
View file @
c454d419
third_party/libxsmm/samples/utilities/mhd/mhd_in.mhd
0 → 100644
View file @
c454d419
File added
third_party/libxsmm/scripts/libxsmm_config.py
0 → 100755
View file @
c454d419
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from
string
import
Template
from
datetime
import
date
import
libxsmm_utilities
import
fnmatch
import
sys
if __name__ == "__main__":
    # Fills the placeholders of the template file given as first argument
    # and prints the result. All further positional arguments are optional
    # and override the default configuration below; everything from the
    # 15th argument onwards is a list of "M_N_K" kernel shapes.
    argc = len(sys.argv)
    if 1 < argc:
        # required argument(s)
        filename = sys.argv[1]

        # default configuration if no arguments are given
        ilp64 = offload = precision = flags = threshold = 0
        sync = jit = 1
        alpha = beta = 1
        cacheline = 64
        prefetch = -1
        wrap = 1
        malloc = 0
        mnklist = list()

        # optional argument(s)
        if 2 < argc:
            ilp64 = int(sys.argv[2])
        if 3 < argc:
            offload = int(sys.argv[3])
        if 4 < argc:
            cacheline = libxsmm_utilities.sanitize_alignment(int(sys.argv[4]))
        if 5 < argc:
            precision = int(sys.argv[5])
        if 6 < argc:
            prefetch = int(sys.argv[6])
        if 7 < argc:
            threshold = int(sys.argv[7])
        if 8 < argc:
            sync = int(sys.argv[8])
        if 9 < argc:
            jit = int(sys.argv[9])
        if 10 < argc:
            flags = int(sys.argv[10])
        if 11 < argc:
            alpha = int(sys.argv[11])
        if 12 < argc:
            beta = int(sys.argv[12])
        if 13 < argc:
            wrap = int(sys.argv[13])
        if 14 < argc:
            malloc = int(sys.argv[14])
        if 15 < argc:
            mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[15:], 0))

        version, branch, realversion = libxsmm_utilities.version_branch()
        major, minor, update, patch = libxsmm_utilities.version_numbers(version)

        # a threshold of zero selects the default of 64^3
        if 0 == threshold:
            threshold = 64 * 64 * 64
        maxmnk = libxsmm_utilities.max_mnk(mnklist, threshold)
        maxdim = int(maxmnk ** (1.0 / 3.0) + 0.5)  # rounded cube root
        avgdim = int(0.5 * maxdim + 0.5)

        # per-dimension median (capped by avgdim) and maximum (at least avgdim)
        avgm = libxsmm_utilities.median(
            [mnk[0] for mnk in mnklist], avgdim, False
        )
        avgn = libxsmm_utilities.median(
            [mnk[1] for mnk in mnklist], avgdim, False
        )
        avgk = libxsmm_utilities.median(
            [mnk[2] for mnk in mnklist], avgdim, False
        )

        maxm = libxsmm_utilities.max_mnk(mnklist, avgdim, 0)
        maxn = libxsmm_utilities.max_mnk(mnklist, avgdim, 1)
        maxk = libxsmm_utilities.max_mnk(mnklist, avgdim, 2)

        substitute = {
            "VERSION": realversion,
            "BRANCH": branch,
            "MAJOR": major,
            "MINOR": minor,
            "UPDATE": update,
            "PATCH": patch,
            "DATE": date.today().strftime("%Y%m%d"),
            "CACHELINE": cacheline,
            "PREFETCH": [-1, prefetch][0 <= prefetch],
            "MAX_MNK": maxmnk,
            "MAX_DIM": maxdim,
            "AVG_DIM": int((maxdim + 1) / 2),
            "MAX_M": [maxdim, maxm][avgm < maxm],
            "MAX_N": [maxdim, maxn][avgn < maxn],
            "MAX_K": [maxdim, maxk][avgk < maxk],
            "FLAGS": flags,
            "ILP64": [0, 1][0 != ilp64],
            "ALPHA": alpha,
            "BETA": beta,
            "WRAP": wrap,
            "MALLOC": malloc,
            "SYNC": [0, 1][0 != sync],
            "JIT": [0, 1][0 != jit],
            "LIBXSMM_OFFLOAD_BUILD": [
                "",
                "\n#define LIBXSMM_OFFLOAD_BUILD",
            ][0 != offload],
            "MNK_PREPROCESSOR_LIST": "",
        }

        # read the template and release the file handle right away
        with open(filename, "r") as template_file:
            template = Template(template_file.read())

        if fnmatch.fnmatch(filename, "*.h*"):
            # C/C++ header: every placeholder must be known (substitute)
            if mnklist:
                first = mnklist[0]
            for mnk in mnklist:
                mnkstr = "_".join(map(str, mnk))
                if mnk != first:
                    substitute["MNK_PREPROCESSOR_LIST"] += "\n"
                if 2 != precision:
                    substitute["MNK_PREPROCESSOR_LIST"] += (
                        "#define LIBXSMM_SMM_" + mnkstr
                    )
                if mnk != first or 0 == precision:
                    substitute["MNK_PREPROCESSOR_LIST"] += "\n"
                if 1 != precision:
                    substitute["MNK_PREPROCESSOR_LIST"] += (
                        "#define LIBXSMM_DMM_" + mnkstr
                    )
            print(template.substitute(substitute))
        else:
            # any other template: unknown placeholders are left untouched
            substitute["BLASINT_KIND"] = ["C_INT", "C_LONG_LONG"][0 != ilp64]
            print(template.safe_substitute(substitute))
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
third_party/libxsmm/scripts/libxsmm_dispatch.py
0 → 100755
View file @
c454d419
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import
libxsmm_utilities
import
sys
import
os
if __name__ == "__main__":
    # Prints the C code that registers statically generated kernels.
    # argv[1] is either a build-state file or "0"; it is followed by the
    # precision, the threshold, and a list of "M_N_K" kernel shapes.
    arguments = sys.argv
    argc = len(arguments)
    if argc <= 1:
        sys.tracebacklimit = 0
        raise ValueError(arguments[0] + ": wrong number of arguments!")

    state_file = "" if "0" == arguments[1] else arguments[1]
    base = 1
    if os.path.isfile(state_file):
        # embed the build state and skip the consumed argument
        for text in (
            "#if !defined(_WIN32)",
            "{ static const char *const build_state =",
            '# include "../' + os.path.basename(state_file) + '"',
            " ;",
            " internal_build_state = build_state;",
            "}",
            "#endif",
        ):
            print(text)
        base = 2

    if (base + 2) < argc:
        precision = int(arguments[base + 0])
        int(arguments[base + 1])  # threshold: validated but not used below
        mnklist = libxsmm_utilities.load_mnklist(arguments[base + 2:], 0)
        for text in (
            "/* omit registering code if JIT is enabled"
            " and if an ISA extension is found",
            " * which is beyond the static code"
            " path used to compile the library",
            " */",
            "#if (0 != LIBXSMM_JIT) && !defined(__MIC__)",
            "if (LIBXSMM_X86_GENERIC > libxsmm_target_archid "
            "/* JIT code gen. is not available */",
            " /* conditions allows to avoid JIT "
            "(if static code is good enough) */",
            " || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)",
            " || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&",
            " libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) ==",
            " libxsmm_cpuid_vlen32(libxsmm_target_archid)))",
            "#endif",
            "{",
            " libxsmm_xmmfunction func;",
        ):
            print(text)
        # prefer registering double-precision kernels
        # when approaching an exhausted registry
        for skipped, letter, bits in ((1, "d", "F64"), (2, "s", "F32")):
            if skipped == precision:
                continue  # this precision was not requested
            for mnk in mnklist:
                mnkstr = "_".join(map(str, mnk))
                mnksig = ", ".join(str(extent) for extent in mnk[0:3])
                print(
                    " func." + letter + "mm = (libxsmm_" + letter
                    + "mmfunction)libxsmm_" + letter + "mm_" + mnkstr + ";"
                )
                print(
                    " internal_register_static_code("
                    + "LIBXSMM_GEMM_PRECISION_" + bits + ", "
                    + mnksig
                    + ", func, new_registry);"
                )
        print("}")
third_party/libxsmm/scripts/libxsmm_interface.py
0 → 100755
View file @
c454d419
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from
string
import
Template
import
libxsmm_utilities
import
fnmatch
import
sys
if __name__ == "__main__":
    # Generates the list of specialized kernel interfaces (C prototypes for
    # "*.h*" templates, Fortran interfaces otherwise), fills it into the
    # template given as first argument, and prints the result.
    argc = len(sys.argv)
    if 1 < argc:
        # required argument(s)
        filename = sys.argv[1]

        # default configuration if no arguments are given
        precision = 0  # all
        ifversion = 1  # interface
        prefetch = -1  # auto
        mnklist = list()

        # optional argument(s)
        if 2 < argc:
            # one integer carries both values: low two bits are the
            # precision, the remaining bits are the interface version
            ivalue = int(sys.argv[2])
            ifversion = ivalue >> 2
            precision = ivalue & 3
        if 3 < argc:
            prefetch = int(sys.argv[3])
        if 4 < argc:
            mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[4:], 0))

        # read the template and release the file handle right away
        with open(filename, "r") as template_file:
            template = Template(template_file.read())

        # (precision value that excludes the variant, letter, C type, kind)
        variants = (
            (2, "s", "float", "C_FLOAT"),
            (1, "d", "double", "C_DOUBLE"),
        )

        if fnmatch.fnmatch(filename, "*.h*"):
            optional = [", ...", ""][0 <= prefetch]
            entries = []
            for mnk in mnklist:
                mnkstr = "_".join(map(str, mnk))
                for excluded, letter, ctype, _ in variants:
                    if excluded == precision:
                        continue
                    if 0 < prefetch:
                        pfsig = (
                            ",\n"
                            "const " + ctype + "* pa, "
                            "const " + ctype + "* pb, "
                            "const " + ctype + "* pc);"
                        )
                    else:
                        pfsig = optional + ");"
                    entries.append(
                        "\nLIBXSMM_API void libxsmm_" + letter + "mm_"
                        + mnkstr
                        + "(const " + ctype + "* a, const " + ctype
                        + "* b, " + ctype + "* c"
                        + pfsig
                    )
                if 0 == precision:
                    entries.append("\n")
            if mnklist and 0 != precision:
                entries.append("\n")
            substitute = {"MNK_INTERFACE_LIST": "".join(entries)}
            print(template.substitute(substitute))
        else:  # Fortran interface
            if 1 > ifversion and 0 != ifversion:
                raise ValueError("Fortran interface level is inconsistent!")
            # Fortran's OPTIONAL allows to always generate an interface
            # with prefetch signature (more flexible usage)
            if 0 == prefetch:
                prefetch = -1
            version, branch, realversion = libxsmm_utilities.version_branch(16)
            major, minor, update, patch = libxsmm_utilities.version_numbers(
                version
            )
            entries = []
            if mnklist:
                entries.append("\n")
                for mnk in mnklist:
                    mnkstr = "_".join(map(str, mnk))
                    if 0 == precision:
                        entries.append(
                            "\n"
                            "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_"
                            + mnkstr
                            + ", libxsmm_dmm_"
                            + mnkstr
                        )
                    elif 2 != precision:
                        entries.append(
                            "\n"
                            "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_smm_"
                            + mnkstr
                        )
                    elif 1 != precision:
                        entries.append(
                            "\n"
                            "!DIR$ ATTRIBUTES OFFLOAD:MIC :: libxsmm_dmm_"
                            + mnkstr
                        )
                entries.append("\nINTERFACE")
                optional = [", OPTIONAL", ""][0 < prefetch]
                bindc = ["", "BIND(C)"][0 < prefetch]
                for mnk in mnklist:
                    mnkstr = "_".join(map(str, mnk))
                    for excluded, letter, _, fkind in variants:
                        if excluded == precision:
                            continue
                        if 0 != prefetch:
                            # continuation line carrying the prefetch args
                            pfsiga = (
                                ","
                                + "&".rjust(26 - len(mnkstr))
                                + "\n& pa, pb, pc) "
                                + bindc
                                + "\n"
                            )
                            pfsigb = (
                                " REAL(" + fkind + "), "
                                "INTENT(IN)" + optional + " :: "
                                "pa(*), "
                                "pb(*), "
                                "pc(*)\n"
                            )
                        else:
                            pfsiga = ") BIND(C)\n"
                            pfsigb = ""
                        entries.append(
                            "\n"
                            "PURE SUBROUTINE libxsmm_" + letter + "mm_"
                            + mnkstr
                            + "(a, b, c"
                            + pfsiga
                            + " IMPORT :: " + fkind + "\n"
                            " REAL(" + fkind + "), "
                            "INTENT(IN) :: a(*), b(*)\n"
                            " REAL(" + fkind + "), "
                            "INTENT(INOUT) :: c(*)\n"
                            + pfsigb
                            + " END SUBROUTINE"
                        )
                entries.append("\nEND INTERFACE")
            substitute = {
                "VERSION": realversion,
                "BRANCH": branch,
                "MAJOR": major,
                "MINOR": minor,
                "UPDATE": update,
                "PATCH": patch,
                "MNK_INTERFACE_LIST": "".join(entries),
                "CONTIGUOUS": ["", ", CONTIGUOUS"][1 < ifversion],
            }
            print(template.safe_substitute(substitute))
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
third_party/libxsmm/scripts/libxsmm_source.sh
0 → 100755
View file @
c454d419
#!/usr/bin/env sh
# Prints the header-only umbrella file (libxsmm_source.h) to standard output.
# Optional argument: directory prefix used in the emitted #include lines.
SRCDIR=../src
GREP=$(command -v grep)

if [ -z "${GREP}" ]; then
  echo "Error: missing prerequisites!" >&2
  exit 1
fi

cat << EOM
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_SOURCE_H
#define LIBXSMM_SOURCE_H
#if defined(LIBXSMM_MACROS_H)
# error Please do not include any LIBXSMM header other than libxsmm_source.h!
#endif
#if defined(LIBXSMM_BUILD)
# error LIBXSMM_BUILD cannot be defined for the header-only LIBXSMM!
#endif
/**
* This header is intentionally called "libxsmm_source.h" since the followings block
* includes *internal* files, and thereby exposes LIBXSMM's implementation.
* The so-called "header-only" usage model gives up the clearly defined binary interface
* (including support for hot-fixes after deployment), and requires to rebuild client
* code for every (internal) change of LIBXSMM. Please make sure to only rely on the
* public interface as the internal implementation may change without notice.
*/
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
EOM

# directory of this script (physical path)
HERE=$(cd "$(dirname "$0")" && pwd -P)

# include prefix: first argument if given (and non-empty), SRCDIR otherwise
DSTDIR=${1:-${SRCDIR}}

# determine order of filenames in directory list
export LC_ALL=C

# good-enough pattern to match a main function, and to exclude this translation unit
for FILE in $(cd "${HERE}/${SRCDIR}" && ${GREP} -L "main[[:space:]]*(.*)" ./*.c); do
  BASENAME=$(basename "${FILE}")
  echo "#include \"${DSTDIR}/${BASENAME}\""
done

cat << EOM
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#endif /*LIBXSMM_SOURCE_H*/
EOM
third_party/libxsmm/scripts/libxsmm_specialized.py
0 → 100755
View file @
c454d419
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import
sys
def _print_kernel(letter, ctype, tag, m, n, k, prefetch):
    """Print one specialized kernel plus its Fortran-symbol wrapper as C code.

    letter   -- "s" or "d"; selects the libxsmm_<letter>mm_ name prefix.
    ctype    -- C element type used in the signature ("float" or "double").
    tag      -- suffix of the LIBXSMM_GENTARGET_* macros ("sp" or "dp").
    m, n, k  -- kernel shape; becomes part of the function name.
    prefetch -- positive: explicit pa/pb/pc parameters; negative: variadic
                signature (", ..."); zero: plain three-argument signature.
    """
    name = "libxsmm_" + letter + "mm_" + str(m) + "_" + str(n) + "_" + str(k)
    if 0 < prefetch:
        signature = "a, b, c, pa, pb, pc"
        pfsig = (
            "\n"
            ", const " + ctype + "* pa"
            ", const " + ctype + "* pb"
            ", const " + ctype + "* pc)"
        )
    else:
        signature = "a, b, c"
        pfsig = ["", ", ..."][0 > prefetch] + ")"
    params = (
        "(const " + ctype + "* a, const " + ctype + "* b, " + ctype + "* c"
        + pfsig
    )
    call = "(" + signature + ");"

    # Two empty lines separate the definitions. A bare `print` is only an
    # expression in Python 3 and prints nothing, hence the explicit calls.
    print()
    print()
    print("LIBXSMM_API void " + name + params)
    print("{")
    print(
        "#if defined(__AVX512F__) && "
        "defined(LIBXSMM_GENTARGET_skx_" + tag + ") &&\\"
    )
    print(" !(defined(__AVX512PF__) && defined(__AVX512ER__))")
    print(" " + name + "_skx" + call)
    for macro, target in (
        ("__AVX512F__", "knl"),
        ("__AVX2__", "hsw"),
        ("__AVX__", "snb"),
        ("__SSE3__", "wsm"),
    ):
        print(
            "#elif defined(" + macro + ") && "
            "defined(LIBXSMM_GENTARGET_" + target + "_" + tag + ")"
        )
        print(" " + name + "_" + target + call)
    # fallback when no generated target matches: inline GEMM
    print("#else")
    print(
        " const char transa = (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & "
        "LIBXSMM_FLAGS) ? 'N' : 'T');"
    )
    print(
        " const char transb = (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & "
        "LIBXSMM_FLAGS) ? 'N' : 'T');"
    )
    print(" const " + ctype + " alpha = LIBXSMM_ALPHA, beta = LIBXSMM_BETA;")
    print(
        " const libxsmm_blasint "
        "m = " + str(m) + ", "
        "n = " + str(n) + ", "
        "k = " + str(k) + ";"
    )
    if 0 < prefetch:
        print(
            " LIBXSMM_UNUSED(pa);"
            " LIBXSMM_UNUSED(pb);"
            " LIBXSMM_UNUSED(pc);"
        )
    print(
        " LIBXSMM_INLINE_XGEMM(" + ctype + ", " + ctype
        + ", &transa, &transb,"
        " &m, &n, &k, &alpha, a, &m, b, &k, &beta, c, &m);"
    )
    print("#endif")
    print("}")
    print()
    print()
    # declaration followed by the definition of the Fortran-symbol wrapper
    print("LIBXSMM_API void LIBXSMM_FSYMBOL(" + name + ")" + params + ";")
    print("LIBXSMM_API void LIBXSMM_FSYMBOL(" + name + ")" + params)
    print("{")
    print(" " + name + call)
    print("}")


if __name__ == "__main__":
    # Usage: libxsmm_specialized.py <precision> <m> <n> <k> <prefetch>
    # precision 1 emits only the single-precision kernel, 2 only the
    # double-precision kernel, any other value emits both.
    argc = len(sys.argv)
    if 6 == argc:
        precision = int(sys.argv[1])
        m, n, k = int(sys.argv[2]), int(sys.argv[3]), int(sys.argv[4])
        prefetch = int(sys.argv[5])

        if 2 != precision:
            _print_kernel("s", "float", "sp", m, n, k, prefetch)
        if 1 != precision:
            _print_kernel("d", "double", "dp", m, n, k, prefetch)
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
third_party/libxsmm/scripts/libxsmm_utilities.py
0 → 100755
View file @
c454d419
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import
itertools
import
operator
import
inspect
import
sys
import
os
try
:
from
functools
import
reduce
except
ImportError
:
pass
def upper_list(lists, level):
    """Return the first non-empty list found when stepping down from `level`.

    The slot inspected for a given level is `level - 1` for positive levels
    and `level + len(lists) - 1` otherwise. The search decrements the level
    until a non-empty entry is found or the level drops below `-len(lists)`,
    in which case an empty list is returned.
    """
    count = len(lists)
    while True:
        slot = (level if 1 <= level else level + count) - 1
        candidate = lists[slot]
        if candidate:
            return candidate
        if level < -count:
            return []
        level -= 1
# https://docs.python.org/3/library/itertools.html#itertools.product
def itertools_product(*args):
    """Generate the Cartesian product of the given iterables as tuples.

    product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    All combinations are built before the first one is yielded.
    """
    combos = [()]
    for pool in [tuple(arg) for arg in args]:
        combos = [head + (item,) for head in combos for item in pool]
    for combo in combos:
        yield combo
def load_mnklist(argv, threshold, inputformat=0, resultset=None):
    """Parse kernel shapes from `argv` and return them as a set of (m, n, k).

    argv        -- list of strings; the layout depends on `inputformat`.
    threshold   -- if non-zero, shapes with m*n*k above it are dropped.
    inputformat -- 0: each item is "M_N_K";
                   -1: comma-separated groups of whitespace-separated sizes,
                       each group expands to all (m, n, k) combinations of
                       its sizes;
                   -2: legacy layout: count of M values, count of N values,
                       then the M, N, and K values.
    resultset   -- optional shapes to merge with; it is not modified.

    Shapes with a non-positive extent are always dropped.
    Raises ValueError for an unknown `inputformat`.
    """
    # Work on a private copy: the caller's set is never mutated and is
    # merged for every input format (not only for the legacy one).
    resultset = set() if resultset is None else set(resultset)
    if 0 == inputformat:  # indexes format
        resultset.update(tuple(map(int, mnk.split("_"))) for mnk in argv)
    elif -1 == inputformat:  # new input format
        for group in " ".join(argv).split(","):
            sizes = [int(i) for i in group.split()]
            resultset.update(itertools.product(sizes, sizes, sizes))
    elif -2 == inputformat:  # legacy format

        def parse(items):
            return [int(str(s).replace(",", " ").strip()) for s in items]

        mcount, ncount = int(argv[0]), int(argv[1])
        mlist = parse(argv[2:2 + mcount])
        nlist = parse(argv[2 + mcount:2 + mcount + ncount])
        klist = parse(argv[2 + mcount + ncount:])
        mnk = [mlist, nlist, klist]
        # an empty dimension borrows the values of another dimension
        top = [
            mlist if mlist else upper_list(mnk, 0),
            nlist if nlist else upper_list(mnk, 1),
            klist if klist else upper_list(mnk, 2),
        ]
        for m in top[0]:
            for n in top[1]:
                if not nlist:
                    n = m
                for k in top[2]:
                    if not klist:
                        k = n
                    if not mlist:
                        m = k
                    resultset.add((m, n, k))
    else:
        sys.tracebacklimit = 0
        raise ValueError("load_mnklist: unexpected input format!")

    def accepted(mnk):
        if not (0 < mnk[0] and 0 < mnk[1] and 0 < mnk[2]):
            return False
        # a threshold of zero means "no threshold requested"
        return 0 == threshold or threshold >= (mnk[0] * mnk[1] * mnk[2])

    return set(filter(accepted, resultset))
def max_mnk(mnklist, init=0, index=None):
    """Return the largest extent (or product of extents) but at least `init`.

    With `index` in 0..2 the given component of every (m, n, k) entry is
    considered; otherwise the product m*n*k of every entry is considered.
    """
    per_component = index is not None and 0 <= index and index < 3
    largest = init
    for mnk in mnklist:
        value = mnk[index] if per_component else mnk[0] * mnk[1] * mnk[2]
        largest = max(largest, value)
    return largest
def median(list_of_numbers, fallback=None, average=True):
    """Return the median of the given numbers.

    list_of_numbers -- any iterable of numbers; it is left unmodified.
    fallback        -- returned if there are no numbers; when numbers are
                       present it acts as an upper bound of the result.
    average         -- for an even count, return the rounded mean of the
                       two middle values instead of the upper middle value.

    Raises ValueError if there are no numbers and no fallback is given.
    """
    # sort a copy rather than the caller's list
    ordered = sorted(list_of_numbers)
    size = len(ordered)
    if 0 < size:
        half = size // 2
        if average and 0 == size % 2:
            medval = int(0.5 * (ordered[half - 1] + ordered[half]) + 0.5)
        else:
            medval = ordered[half]
        if fallback is not None:
            result = min(medval, fallback)
        else:
            result = medval
    elif fallback is not None:
        result = fallback
    else:
        sys.tracebacklimit = 0
        raise ValueError("median: empty list!")
    return result
def is_pot(num):
    """Return True if `num` is a power of two (zero is accepted as well).

    Both conditions must hold: the number is non-negative and it has at
    most one bit set (`num & (num - 1)` clears the lowest set bit).
    """
    return 0 <= num and 0 == (num & (num - 1))
def sanitize_alignment(alignment):
    """Return a usable alignment for the given request.

    Zero maps to 1 and negative values map to 64. Positive values are
    returned unchanged unless is_pot() rejects them, in which case a
    ValueError is raised.
    """
    if 0 < alignment:
        if not is_pot(alignment):
            sys.tracebacklimit = 0
            raise ValueError(
                "sanitize_alignment: alignment must be a Power of Two (POT)!"
            )
        return alignment
    return 64 if 0 != alignment else 1
def align_value(n, typesize, alignment):
    """Round `n` elements up so that their byte size is a multiple of
    `alignment`, and return the resulting element count.

    n         -- number of elements.
    typesize  -- size of one element in bytes (must be positive).
    alignment -- alignment in bytes (must be positive).

    Raises ValueError if `typesize` or `alignment` is not positive.
    """
    if 0 < typesize and 0 < alignment:
        # Floor division is required: "/" is true division in Python 3 and
        # would yield an unaligned floating-point result.
        return (
            ((n * typesize + alignment - 1) // alignment) * alignment
        ) // typesize
    else:
        sys.tracebacklimit = 0
        raise ValueError("align_value: invalid input!")
def version_branch_from_file(version_filepath):
    """Read a version file and split its content into (version, branch).

    The content (newlines removed) is split at "-". The leading words that
    are not purely numeric (digits separated by ".") form the branch; the
    remaining words, re-joined with "-", form the version.
    """
    separator = "-"
    with open(version_filepath, "r") as stream:
        words = stream.read().replace("\n", "").split(separator)
    branch_words = []
    for word in words:
        if all(part.isdigit() for part in word.split(".")):
            break  # first numeric word starts the version
        branch_words.append(word)
    version = separator.join(words[len(branch_words):])
    return (version, separator.join(branch_words))
def version_numbers(version, branch=None):
    """Split a version string into (major, minor, update, patch).

    version -- text like "1.17", "1.17-42" or "master-1.17-42"; a leading
               non-numeric word is taken as the branch name, otherwise the
               branch is assumed to be "master".
    branch  -- if given and different from the branch found in `version`,
               (-1, -1, -1, -1) is returned.

    Components which are missing or not numeric keep their defaults
    (major 1, everything else 0).
    """
    version_list = version.split("-")
    # Slicing (rather than indexing) keeps an empty first word, e.g. an
    # empty version string, from raising IndexError.
    if not version_list[0][:1].isdigit():
        vbranch = version_list[0]
    else:
        vbranch = "master"
    if branch is None or vbranch == branch:
        minor = update = patch = 0
        major = 1
        n = len(version_list)
        if 1 < n:
            patch_list = version_list[n - 1]
            if 1 == len(patch_list.split(".")):
                # the last word is a single number: the patch level, or
                # the major version if it directly follows the branch name
                version_list = version_list[n - 2].split(".")
                if version_list != [vbranch]:
                    patch = int(patch_list)
                else:
                    major = int(patch_list)
            else:
                version_list = patch_list.split(".")
        else:
            version_list = version.split(".")
        n = len(version_list)
        try:
            if 0 < n:
                major = int(version_list[0])
            if 1 < n:
                minor = int(version_list[1])
            if 2 < n:
                update = int(version_list[2])
        except ValueError:
            # non-numeric component: keep the values determined so far
            pass
    else:
        major = minor = update = patch = -1
    return major, minor, update, patch
def version_branch(max_strlen=-1):
    """Return (version, branch, realversion) taken from "version.txt".

    The file one directory above this script provides `realversion` and
    the branch. If a different "version.txt" exists in the current working
    directory and carries higher version numbers, its version is returned
    as `version`. A positive `max_strlen` limits the length of the branch
    name, preferably cutting at the last "-", "_" or "." found beyond the
    first third of that limit.
    """
    basename = "version.txt"
    script_dir = os.path.dirname(inspect.getfile(inspect.currentframe()))
    default_path = os.path.realpath(os.path.join(script_dir, "..", basename))
    local_path = os.path.realpath(basename)  # local version file
    realversion, branch = version_branch_from_file(default_path)
    version = realversion
    if default_path != local_path and os.path.isfile(local_path):
        candidate = version_branch_from_file(local_path)[0]
        if version_numbers(realversion) < version_numbers(candidate):
            version = candidate
    if 0 < max_strlen:
        lower = int(max_strlen / 3)
        cut = max(
            branch.rfind(mark, lower, max_strlen) for mark in ("-", "_", ".")
        )
        branch = branch[0:cut] if lower < cut else branch[0:max_strlen]
    return (version, branch, realversion)
if __name__ == "__main__":
    # Command-line front end; the first argument selects the mode:
    #  -1: print the list of kernel shapes, or (with exactly two arguments)
    #      "1"/"0" depending on whether the given release version has a
    #      patch level of zero
    #   0: print the full version, or (with exactly two arguments) the
    #      major number of the given version
    # 1-4: print major, minor, update, or patch of the library version
    argc = len(sys.argv)
    arg1 = int(sys.argv[1]) if 1 < argc else 0
    if arg1 < -1:
        sys.tracebacklimit = 0
        raise ValueError(
            "{}: wrong ({}) number of arguments ('{}') given!".format(
                sys.argv[0], argc - 1, " ".join(sys.argv[1:])
            )
        )
    if -1 == arg1:
        if 5 < argc:
            # threshold = int(sys.argv[2])
            mnk_size = int(sys.argv[3])
            dims = load_mnklist(sys.argv[4:4 + mnk_size], 0, -1)
            dims = load_mnklist(sys.argv[4 + mnk_size:], 0, -2, dims)
            print(" ".join("_".join(map(str, mnk)) for mnk in sorted(dims)))
        elif 3 == argc:
            numbers = version_numbers(sys.argv[2], "release")
            print("1" if 0 == numbers[3] else "0")
    elif 0 == arg1 and 3 == argc:
        print(version_numbers(sys.argv[2])[0])  # soname version
    else:
        version, branch, realversion = version_branch()
        numbers = version_numbers(version)
        if 1 <= arg1 <= 4:
            print(numbers[arg1 - 1])
        elif "" != branch:
            print("{}-{}".format(branch, realversion))
        else:
            print(realversion)
third_party/libxsmm/scripts/libxsmm_version.sh
0 → 100755
View file @
c454d419
#!/usr/bin/env sh
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
# Prints "<branch>[-<tag>]-<count>": <tag> is the latest tag starting with
# a digit (if any) and <count> is the number of non-merge commits since that
# tag (or in total), plus the optional shift given as first argument.
GIT=$(command -v git)
SHIFT=${1:-0}

NAME=$(${GIT} rev-parse --abbrev-ref HEAD 2>/dev/null)
MAIN=$(${GIT} describe --tags --match "[0-9]*" --abbrev=0 2>/dev/null)

if [ -n "${MAIN}" ]; then
  VERSION="${NAME}-${MAIN}"
  RANGE="${MAIN}..HEAD"
else
  VERSION=${NAME}
  RANGE=HEAD
fi
REVC=$(${GIT} rev-list --count --no-merges "${RANGE}" 2>/dev/null)

echo "${VERSION}-$((REVC+SHIFT))"
third_party/libxsmm/src/libxsmm_cpuid_arm.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm_cpuid.h>
#include <libxsmm_generator.h>
#include <libxsmm_memory.h>
#include <libxsmm_sync.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <signal.h>
#include <setjmp.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/*
 * LIBXSMM_CPUID_ARM_MRS(RESULT, ID): reads the system register ID into RESULT.
 * - MSVC: ID is a 16-bit value built by LIBXSMM_CPUID_ARM_ENC16 from the
 *   (OP0, OP1, CRN, CRM, OP2) fields and handed to _ReadStatusReg.
 * - Otherwise: ID is stringified and used as operand of an inline "mrs".
 */
#if defined(_MSC_VER)
/* Packs the masked fields into one value: OP0 (1 bit) at bit 14, OP1 at 11, CRN at 7, CRM at 3, OP2 at 0. */
# define LIBXSMM_CPUID_ARM_ENC16(OP0, OP1, CRN, CRM, OP2) ( \
    (((OP0) & 1) << 14) | \
    (((OP1) & 7) << 11) | \
    (((CRN) & 15) << 7) | \
    (((CRM) & 15) << 3) | \
    (((OP2) & 7) << 0))
# define ID_AA64ISAR1_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0110, 0b001)
# define ID_AA64PFR0_EL1 LIBXSMM_CPUID_ARM_ENC16(0b11, 0b000, 0b0000, 0b0100, 0b000)
# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) RESULT = _ReadStatusReg(ID)
#else
# define LIBXSMM_CPUID_ARM_MRS(RESULT, ID) __asm__ __volatile__( \
    "mrs %0," LIBXSMM_STRINGIFY(ID) : "=r"(RESULT))
#endif
#if defined(LIBXSMM_PLATFORM_AARCH64)
/* Jump target; armed with setjmp right before a system register is probed. */
LIBXSMM_APIVAR_DEFINE(jmp_buf internal_cpuid_arm_jmp_buf);

LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int /*signum*/);

/**
 * Signal handler used while probing system registers: re-registers itself
 * for the delivered signal and, unless that registration failed, jumps back
 * to the most recent setjmp on internal_cpuid_arm_jmp_buf (passing 1).
 */
LIBXSMM_API_INTERN void internal_cpuid_arm_sigill(int signum)
{
  void (*const previous)(int) = signal(signum, internal_cpuid_arm_sigill);
  LIBXSMM_ASSERT(SIGILL == signum);
  if (previous != SIG_ERR) {
    longjmp(internal_cpuid_arm_jmp_buf, 1);
  }
}
#endif
/**
 * Determines the AArch64 target ID of the executing CPU.
 *
 * On AArch64 (other than Apple arm64, which is fixed to LIBXSMM_AARCH64_V81),
 * the ID is detected once and cached: ID_AA64ISAR1_EL1 and ID_AA64PFR0_EL1
 * are read under a temporary SIGILL handler, and the result is upgraded from
 * LIBXSMM_AARCH64_V81 to LIBXSMM_AARCH64_V82 or LIBXSMM_AARCH64_A64FX.
 * On non-ARM platforms, a warning is printed once (debug builds, non-zero
 * verbosity) and the initial LIBXSMM_TARGET_ARCH_UNKNOWN is returned.
 *
 * @param info Optional (may be NULL); if given, it is zero-filled on every
 *             call, i.e., also when the cached result is returned.
 * @return The detected (or cached) target ID.
 */
LIBXSMM_API int libxsmm_cpuid_arm(libxsmm_cpuid_info* info)
{
  static int result = LIBXSMM_TARGET_ARCH_UNKNOWN;
#if defined(LIBXSMM_PLATFORM_AARCH64)
# if defined(__APPLE__) && defined(__arm64__)
  result = LIBXSMM_AARCH64_V81;
# else
  if (LIBXSMM_TARGET_ARCH_UNKNOWN == result) { /* avoid redetecting features */
    void (*const handler)(int) = signal(SIGILL, internal_cpuid_arm_sigill);
    /* baseline, which is kept if the handler cannot be installed or a probe faults */
    result = LIBXSMM_AARCH64_V81;
    if (SIG_ERR != handler) {
      uint64_t capability; /* 64-bit value */
      /* setjmp yields non-zero if the following register read raised SIGILL */
      if (0 == setjmp(internal_cpuid_arm_jmp_buf)) {
        LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64ISAR1_EL1);
        if (0xF & capability) { /* DPB */
          result = LIBXSMM_AARCH64_V82;
          if (0 == setjmp(internal_cpuid_arm_jmp_buf)) {
            LIBXSMM_CPUID_ARM_MRS(capability, ID_AA64PFR0_EL1);
            if (0xF & (capability >> 32)) { /* SVE */
              result = LIBXSMM_AARCH64_A64FX;
            }
          }
        }
      }
      /* restore original state */
      signal(SIGILL, handler);
    }
  }
# endif
#else
# if !defined(NDEBUG)
  static int error_once = 0;
  if (0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_arm called on non-ARM platform!\n");
  }
# endif
#endif
  /* always hand back a defined structure (not only during the first detection) */
  if (NULL != info) LIBXSMM_MEMZERO127(info);
  return result;
}
third_party/libxsmm/src/libxsmm_cpuid_x86.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm_generator.h>
#include <libxsmm_memory.h>
#include <libxsmm_sync.h>
#if !defined(_WIN32)
# include <sys/mman.h>
#endif
#if defined(LIBXSMM_PLATFORM_X86)
/* XGETBV: receive results (EAX, EDX) for eXtended Control Register (XCR). */
/* CPUID, receive results (EAX, EBX, ECX, EDX) for requested FUNCTION/SUBFN. */
#if defined(_MSC_VER)
/*defined(_WIN32) && !defined(__GNUC__)*/
# define LIBXSMM_XGETBV(XCR, EAX, EDX) { \
unsigned long long libxsmm_xgetbv_ = _xgetbv(XCR); \
EAX = (int)libxsmm_xgetbv_; \
EDX = (int)(libxsmm_xgetbv_ >> 32); \
}
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) { \
int libxsmm_cpuid_x86_[
/*4*/
] = { 0, 0, 0, 0 }; \
__cpuidex(libxsmm_cpuid_x86_, FUNCTION, SUBFN); \
EAX = (unsigned int)libxsmm_cpuid_x86_[0]; \
EBX = (unsigned int)libxsmm_cpuid_x86_[1]; \
ECX = (unsigned int)libxsmm_cpuid_x86_[2]; \
EDX = (unsigned int)libxsmm_cpuid_x86_[3]; \
}
# elif defined(__GNUC__) || !defined(_CRAYC)
# if (64 > (LIBXSMM_BITS))
LIBXSMM_EXTERN
LIBXSMM_RETARGETABLE
int
__get_cpuid
(
/* prototype */
unsigned
int
,
unsigned
int
*
,
unsigned
int
*
,
unsigned
int
*
,
unsigned
int
*
);
# define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0xFFFFFFFF
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \
EAX = (EBX) = (EDX) = 0; ECX = (SUBFN); \
__get_cpuid(FUNCTION, &(EAX), &(EBX), &(ECX), &(EDX))
# else
/* 64-bit */
# define LIBXSMM_XGETBV(XCR, EAX, EDX) __asm__ __volatile__( \
".byte 0x0f, 0x01, 0xd0"
/*xgetbv*/
: "=a"(EAX), "=d"(EDX) : "c"(XCR) \
)
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) \
__asm__ __volatile__ (".byte 0x0f, 0xa2"
/*cpuid*/
\
: "=a"(EAX), "=b"(EBX), "=c"(ECX), "=d"(EDX) \
: "a"(FUNCTION), "b"(0), "c"(SUBFN), "d"(0) \
)
# endif
# else
/* legacy Cray Compiler */
# define LIBXSMM_XGETBV(XCR, EAX, EDX) EAX = (EDX) = 0
# define LIBXSMM_CPUID_X86(FUNCTION, SUBFN, EAX, EBX, ECX, EDX) EAX = (EBX) = (ECX) = (EDX) = 0
# endif
#endif
#define LIBXSMM_CPUID_CHECK(VALUE, CHECK) ((CHECK) == ((CHECK) & (VALUE)))
/**
 * Determines the x86 target ID, i.e., the ISA extensions supported by the CPU
 * limited (except on macOS) to what the OS permits according to XGETBV.
 *
 * The result is cached in a function-local static; detection is repeated
 * only if info is given. During the first detection (and non-zero verbosity),
 * warnings about limited static code paths, debug code, and OS restrictions
 * may be printed to stderr. On Linux, the first detection also reads
 * /sys/fs/selinux/enforce to set libxsmm_se.
 *
 * @param info Optional (may be NULL); if given, constant_tsc and has_context
 *             are filled. If CPUID reports no leaf beyond zero, or on non-x86
 *             platforms, info is zero-filled instead.
 * @return The detected (or cached) target ID; LIBXSMM_X86_GENERIC if CPUID
 *         reports no leaf beyond zero.
 */
LIBXSMM_API int libxsmm_cpuid_x86(libxsmm_cpuid_info* info)
{
  static int result = LIBXSMM_TARGET_ARCH_UNKNOWN;
#if defined(LIBXSMM_PLATFORM_X86)
  unsigned int eax, ebx, ecx, edx;
  LIBXSMM_CPUID_X86(0, 0/*ecx*/, eax, ebx, ecx, edx);
  if (1 <= eax) { /* CPUID max. leaf */
    /* avoid redetecting features but redetect on request (info given) */
    if (LIBXSMM_TARGET_ARCH_UNKNOWN == result || NULL != info) {
      int feature_cpu = LIBXSMM_X86_GENERIC, feature_os = LIBXSMM_X86_GENERIC, has_context = 0;
      unsigned int maxleaf = eax;
# if defined(__linux__)
      /* only attempted during the first detection and while libxsmm_se is zero */
      if (0 == libxsmm_se && LIBXSMM_TARGET_ARCH_UNKNOWN == result) {
        FILE *const selinux = fopen("/sys/fs/selinux/enforce", "rb");
        if (NULL != selinux) {
          /* a single character is read; anything but '0' counts as enforcing */
          if (1 == fread(&libxsmm_se, 1/*sizeof(char)*/, 1/*count*/, selinux)) {
            libxsmm_se = ('0' != libxsmm_se ? 1 : 0);
          }
          else { /* conservative assumption in case of read-error */
            libxsmm_se = 1;
          }
          fclose(selinux);
        }
      }
# elif defined(MAP_JIT)
      libxsmm_se = 1;
# endif
      /* leaf 1: ecx is kept untouched until the XSAVE/OSXSAVE check further below */
      LIBXSMM_CPUID_X86(1, 0/*ecx*/, eax, ebx, ecx, edx);
      if (LIBXSMM_CPUID_CHECK(ecx, 0x00000001)) { /* SSE3(0x00000001) */
        if (LIBXSMM_CPUID_CHECK(ecx, 0x00100000)) { /* SSE42(0x00100000) */
          if (LIBXSMM_CPUID_CHECK(ecx, 0x10000000)) { /* AVX(0x10000000) */
            if (LIBXSMM_CPUID_CHECK(ecx, 0x00001000)) { /* FMA(0x00001000) */
              unsigned int ecx2;
              /* leaf 7 (sub-leaf 0): results go to ebx, ecx2, and edx */
              LIBXSMM_CPUID_X86(7, 0/*ecx*/, eax, ebx, ecx2, edx);
              /* AVX512F(0x00010000), AVX512CD(0x10000000) */
              if (LIBXSMM_CPUID_CHECK(ebx, 0x10010000)) { /* Common */
                /* AVX512DQ(0x00020000), AVX512BW(0x40000000), AVX512VL(0x80000000) */
                if (LIBXSMM_CPUID_CHECK(ebx, 0xC0020000)) { /* AVX512-Core */
                  if (LIBXSMM_CPUID_CHECK(ecx2, 0x00000800)) { /* VNNI */
                    unsigned int edx2; /* we need to save edx for AMX check */
# if 0 /* no check required yet */
                    unsigned int ecx3;
                    LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx3, edx);
# else
                    /* leaf 7 (sub-leaf 1): overwrites eax, ebx, and ecx2; edx still holds sub-leaf 0 */
                    LIBXSMM_CPUID_X86(7, 1/*ecx*/, eax, ebx, ecx2, edx2);
# endif
                    if (LIBXSMM_CPUID_CHECK(eax, 0x00000020)) { /* BF16 */
                      feature_cpu = LIBXSMM_X86_AVX512_CPX;
                      if (LIBXSMM_CPUID_CHECK(edx, 0x03400000)) { /* AMX-TILE, AMX-INT8, AMX-BF16 */
                        feature_cpu = LIBXSMM_X86_AVX512_SPR;
                      }
                    }
                    else feature_cpu = LIBXSMM_X86_AVX512_CLX; /* CLX */
                  }
                  else feature_cpu = LIBXSMM_X86_AVX512_CORE; /* SKX */
                }
                /* AVX512PF(0x04000000), AVX512ER(0x08000000) */
                else if (LIBXSMM_CPUID_CHECK(ebx, 0x0C000000)) { /* AVX512-MIC */
                  if (LIBXSMM_CPUID_CHECK(edx, 0x0000000C)) { /* KNM */
                    feature_cpu = LIBXSMM_X86_AVX512_KNM;
                  }
                  else feature_cpu = LIBXSMM_X86_AVX512_MIC; /* KNL */
                }
                else feature_cpu = LIBXSMM_X86_AVX512; /* AVX512-Common */
              }
              else feature_cpu = LIBXSMM_X86_AVX2;
            }
            else feature_cpu = LIBXSMM_X86_AVX;
          }
          else feature_cpu = LIBXSMM_X86_SSE42;
        }
        else feature_cpu = LIBXSMM_X86_SSE3;
      }
# if !defined(LIBXSMM_INTRINSICS_DEBUG)
      LIBXSMM_ASSERT_MSG(LIBXSMM_STATIC_TARGET_ARCH <= LIBXSMM_MAX(LIBXSMM_X86_GENERIC, feature_cpu), "missed detecting ISA extensions");
      /* coverity[dead_error_line] */
      if (LIBXSMM_STATIC_TARGET_ARCH > feature_cpu) feature_cpu = LIBXSMM_STATIC_TARGET_ARCH;
# endif
      /* XSAVE/XGETBV(0x04000000), OSXSAVE(0x08000000) */
      if (LIBXSMM_CPUID_CHECK(ecx, 0x0C000000)) { /* OS SSE support */
        feature_os = LIBXSMM_MIN(LIBXSMM_X86_SSE42, feature_cpu);
        if (LIBXSMM_X86_AVX <= feature_cpu) {
          /* XCR0: eax and edx are overwritten with the register-state bits enabled by the OS */
          LIBXSMM_XGETBV(0, eax, edx);
          if (LIBXSMM_CPUID_CHECK(eax, 0x00000006)) { /* OS XSAVE 256-bit */
            feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX2, feature_cpu);
            if (LIBXSMM_CPUID_CHECK(eax, 0x000000E0)) { /* OS XSAVE 512-bit */
              feature_os = LIBXSMM_MIN(LIBXSMM_X86_AVX512_CPX, feature_cpu);
              if (LIBXSMM_X86_AVX512_SPR <= feature_cpu && 7 <= maxleaf
                && LIBXSMM_CPUID_CHECK(eax, 0x00060000)) /* OS XSAVE AMX state (bits 17 and 18) */
              {
                feature_os = feature_cpu; /* unlimited AMX */
              }
            }
          }
        }
      }
      else if (LIBXSMM_X86_GENERIC <= feature_cpu) {
        /* assume FXSAVE, which should be fine
         * 16 years after the first x86_64 OS
         */
        feature_os = LIBXSMM_X86_SSE42;
      }
      else feature_os = LIBXSMM_TARGET_ARCH_GENERIC;
      /* context is considered available if the OS (or the static target) covers the CPU features */
      has_context = (LIBXSMM_STATIC_TARGET_ARCH >= feature_cpu || feature_os >= feature_cpu) ? 1 : 0;
      if (LIBXSMM_TARGET_ARCH_UNKNOWN == result && 0 != libxsmm_verbosity) { /* library code is expected to be mute */
# if !defined(LIBXSMM_TARGET_ARCH)
        const int target_vlen32 = libxsmm_cpuid_vlen32(feature_cpu);
        /* "" if the static code path has a shorter vector length, "highly " if it only
         * lacks features (reported for verbosity >= 2 or negative verbosity), NULL: no warning
         */
        const char *const compiler_support = (libxsmm_cpuid_vlen32(LIBXSMM_MAX_STATIC_TARGET_ARCH) < target_vlen32
          ? "" : (((2 <= libxsmm_verbosity || 0 > libxsmm_verbosity) && LIBXSMM_MAX_STATIC_TARGET_ARCH < feature_cpu)
            ? "highly " : NULL));
        if (NULL != compiler_support) {
          const char *const name = libxsmm_cpuid_name( /* exclude MIC when running on Core processors */
            (((LIBXSMM_X86_AVX512_MIC == LIBXSMM_MAX_STATIC_TARGET_ARCH) ||
              (LIBXSMM_X86_AVX512_KNM == LIBXSMM_MAX_STATIC_TARGET_ARCH)) && (LIBXSMM_X86_AVX512_CORE <= feature_cpu))
              ? LIBXSMM_X86_AVX2 : LIBXSMM_MAX_STATIC_TARGET_ARCH);
          fprintf(stderr, "LIBXSMM WARNING: %soptimized non-JIT code paths are limited to \"%s\"!\n", compiler_support, name);
        }
# endif
# if !defined(NDEBUG) && defined(__OPTIMIZE__)
        fprintf(stderr, "LIBXSMM WARNING: library is optimized without -DNDEBUG and contains debug code!\n");
# endif
# if !defined(__APPLE__) || !defined(__MACH__) /* permitted features */
        if (0 == has_context) {
          fprintf(stderr, "LIBXSMM WARNING: detected CPU features are not permitted by the OS!\n");
          if (0 == libxsmm_se) {
            fprintf(stderr, "LIBXSMM WARNING: downgraded code generation to supported features!\n");
          }
        }
# endif
      }
      /* macOS is faulting AVX-512 (on-demand larger state) */
      result = feature_cpu;
# if !defined(__APPLE__) || !defined(__MACH__)
# if 0 /* opportunistic */
      if (0 == libxsmm_se)
# endif
      { /* only permitted features */
        result = LIBXSMM_MIN(feature_cpu, feature_os);
      }
# endif
      if (NULL != info) {
        /* extended leaf 0x80000007: bit 8 of edx is reported as constant_tsc */
        LIBXSMM_CPUID_X86(0x80000007, 0/*ecx*/, eax, ebx, ecx, edx);
        info->constant_tsc = LIBXSMM_CPUID_CHECK(edx, 0x00000100);
        info->has_context = has_context;
      }
    }
  }
  else {
    if (NULL != info) LIBXSMM_MEMZERO127(info);
    result = LIBXSMM_X86_GENERIC;
  }
#else
# if !defined(NDEBUG)
  static int error_once = 0;
  if (0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM WARNING: libxsmm_cpuid_x86 called on non-x86 platform!\n");
  }
# endif
  if (NULL != info) LIBXSMM_MEMZERO127(info);
#endif
  return result;
}
/**
 * Returns the target ID of the executing CPU by dispatching to the x86
 * detection if compiled for x86, and to the ARM detection otherwise
 * (no additional information is requested).
 */
LIBXSMM_API int libxsmm_cpuid(void)
{
  int target_id;
#if defined(LIBXSMM_PLATFORM_X86)
  target_id = libxsmm_cpuid_x86(NULL/*info*/);
#else
  target_id = libxsmm_cpuid_arm(NULL/*info*/);
#endif
  return target_id;
}
/**
* This implementation also accounts for non-x86 platforms,
* which not only allows to resolve any given ID but to
* fallback gracefully ("unknown").
*/
LIBXSMM_API
const
char
*
libxsmm_cpuid_name
(
int
id
)
{
const
char
*
target_arch
=
NULL
;
switch
(
id
)
{
case
LIBXSMM_X86_AVX512_SPR
:
{
target_arch
=
"spr"
;
}
break
;
case
LIBXSMM_X86_AVX512_CPX
:
{
target_arch
=
"cpx"
;
}
break
;
case
LIBXSMM_X86_AVX512_CLX
:
{
target_arch
=
"clx"
;
}
break
;
case
LIBXSMM_X86_AVX512_CORE
:
{
target_arch
=
"skx"
;
}
break
;
case
LIBXSMM_X86_AVX512_KNM
:
{
target_arch
=
"knm"
;
}
break
;
case
LIBXSMM_X86_AVX512_MIC
:
{
target_arch
=
"knl"
;
}
break
;
case
LIBXSMM_X86_AVX512
:
{
/* TODO: rework BE to use target ID instead of set of strings (target_arch = "avx3") */
target_arch
=
"hsw"
;
}
break
;
case
LIBXSMM_X86_AVX2
:
{
target_arch
=
"hsw"
;
}
break
;
case
LIBXSMM_X86_AVX
:
{
target_arch
=
"snb"
;
}
break
;
case
LIBXSMM_X86_SSE42
:
{
target_arch
=
"wsm"
;
}
break
;
case
LIBXSMM_X86_SSE3
:
{
target_arch
=
"sse3"
;
}
break
;
case
LIBXSMM_AARCH64_V81
:
{
target_arch
=
"aarch64"
;
}
break
;
case
LIBXSMM_AARCH64_A64FX
:
{
target_arch
=
"a64fx"
;
}
break
;
case
LIBXSMM_TARGET_ARCH_GENERIC
:
{
target_arch
=
"generic"
;
}
break
;
default:
if
(
LIBXSMM_X86_GENERIC
<=
id
)
{
target_arch
=
"x86_64"
;
}
else
{
target_arch
=
"unknown"
;
}
}
LIBXSMM_ASSERT
(
NULL
!=
target_arch
);
return
target_arch
;
}
/**
 * Returns the vector length (number of 32-bit elements) associated
 * with the given target ID.
 * This implementation also accounts for non-x86 platforms, which
 * not only allows resolving any given ID but also falls back
 * gracefully (scalar, i.e., 1) for IDs not covered on the
 * platform the code was compiled for.
 *
 * @param id Target ID, e.g., as returned by libxsmm_cpuid.
 * @return 16, 8, 4, or 1 (scalar).
 */
LIBXSMM_API int libxsmm_cpuid_vlen32(int id)
{
  int result;
#if defined(LIBXSMM_PLATFORM_X86)
  if (LIBXSMM_X86_AVX512 <= id) {
    result = 16;
  }
  else if (LIBXSMM_X86_AVX <= id) {
    result = 8;
  }
  else if (LIBXSMM_X86_SSE42 <= id) {
    result = 4;
  }
  else
#elif defined(LIBXSMM_PLATFORM_AARCH64)
  /* V82 is detected as an upgrade of V81 (see libxsmm_cpuid_arm),
   * hence it must not be treated as scalar
   */
  if (LIBXSMM_AARCH64_V81 == id || LIBXSMM_AARCH64_V82 == id) {
    result = 4;
  }
  else if (LIBXSMM_AARCH64_A64FX == id) {
    result = 16;
  }
  else
#else
  LIBXSMM_UNUSED(id);
#endif
  { /* scalar */
    result = 1;
  }
  return result;
}
third_party/libxsmm/src/libxsmm_diff.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DIFF_H
#define LIBXSMM_DIFF_H

#include <libxsmm_intrinsics_x86.h>

/* The AVX-512 based comparison is disabled by default (note the "&& 0"). */
#if !defined(LIBXSMM_DIFF_AVX512_ENABLED) && 0
# define LIBXSMM_DIFF_AVX512_ENABLED
#endif

/*
 * Every size/ISA comes as a group of four macros:
 *   *_DECL(A)        declares the variable(s) holding the reference block A,
 *   *_ASSIGN(A, B)   copies the state of B (declared by *_DECL) into A,
 *   *_LOAD(A, SRC)   makes A refer to (or load) the data at SRC,
 *   *(A, B, ...)     yields a non-zero unsigned char if A differs from
 *                    the memory at B (the variadic part is not evaluated
 *                    by the fixed-size kernels themselves).
 * Composite variants (32, 48, 64SW) declare additional variables whose
 * names are derived from A by token concatenation.
 */
#define LIBXSMM_DIFF_4_DECL(A) const uint32_t * /*const*/ A = NULL
#define LIBXSMM_DIFF_4_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_4_LOAD(A, SRC) A = (const uint32_t*)(SRC)
#define LIBXSMM_DIFF_4(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint32_t*)(B)))))

#define LIBXSMM_DIFF_8_DECL(A) const uint64_t * /*const*/ A = NULL
#define LIBXSMM_DIFF_8_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_8_LOAD(A, SRC) A = (const uint64_t*)(SRC)
#define LIBXSMM_DIFF_8(A, B, ...) ((unsigned char)(0 != (*(A) ^ (*(const uint64_t*)(B)))))

/* 16 Byte using SSE: equal only if the byte-wise comparison sets all 16 mask bits. */
#define LIBXSMM_DIFF_SSE_DECL(A) __m128i A = LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128()
#define LIBXSMM_DIFF_SSE_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_SSE_LOAD(A, SRC) A = LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(SRC))
#define LIBXSMM_DIFF_SSE(A, B, ...) ((unsigned char)(0xFFFF != _mm_movemask_epi8(_mm_cmpeq_epi8( \
  A, LIBXSMM_INTRINSICS_LOADU_SI128((const __m128i*)(B))))))

/* 16 Byte: SSE if statically targeted, otherwise two 64-bit words are compared. */
#if (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH) /*|| defined(LIBXSMM_INTRINSICS_TARGET)*/
# define LIBXSMM_DIFF_16_DECL LIBXSMM_DIFF_SSE_DECL
# define LIBXSMM_DIFF_16_ASSIGN LIBXSMM_DIFF_SSE_ASSIGN
# define LIBXSMM_DIFF_16_LOAD LIBXSMM_DIFF_SSE_LOAD
# define LIBXSMM_DIFF_16 LIBXSMM_DIFF_SSE
#else
# define LIBXSMM_DIFF_16_DECL(A) const uint64_t * /*const*/ A = NULL
# define LIBXSMM_DIFF_16_ASSIGN(A, B) (A) = (B)
# define LIBXSMM_DIFF_16_LOAD(A, SRC) A = (const uint64_t*)(SRC)
# define LIBXSMM_DIFF_16(A, B, ...) ((unsigned char)(0 != (((A)[0] ^ (*(const uint64_t*)(B))) | \
    ((A)[1] ^ ((const uint64_t*)(B))[1]))))
#endif

/* 32 Byte using AVX2: equal only if the byte-wise comparison sets all 32 mask bits (-1). */
#define LIBXSMM_DIFF_AVX2_DECL(A) __m256i A = LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256()
#define LIBXSMM_DIFF_AVX2_ASSIGN(A, B) (A) = (B)
#define LIBXSMM_DIFF_AVX2_LOAD(A, SRC) A = _mm256_loadu_si256((const __m256i*)(SRC))
#define LIBXSMM_DIFF_AVX2(A, B, ...) ((unsigned char)(-1 != _mm256_movemask_epi8(_mm256_cmpeq_epi8( \
  A, _mm256_loadu_si256((const __m256i*)(B))))))

/* 32 Byte: AVX2 if statically targeted, otherwise two 16-Byte halves (second half at +2 uint64_t). */
#if (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_DIFF_32_DECL LIBXSMM_DIFF_AVX2_DECL
# define LIBXSMM_DIFF_32_ASSIGN LIBXSMM_DIFF_AVX2_ASSIGN
# define LIBXSMM_DIFF_32_LOAD LIBXSMM_DIFF_AVX2_LOAD
# define LIBXSMM_DIFF_32 LIBXSMM_DIFF_AVX2
#else
# define LIBXSMM_DIFF_32_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_16_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _))
# define LIBXSMM_DIFF_32_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_16_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_32_, B, _))
# define LIBXSMM_DIFF_32_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_16_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(SRC) + 2)
# define LIBXSMM_DIFF_32(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_16(LIBXSMM_CONCATENATE3(libxsmm_diff_32_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__)))
#endif

/* 48 Byte: a 16-Byte part followed by a 32-Byte part (at +2 uint64_t). */
#define LIBXSMM_DIFF_48_DECL(A) LIBXSMM_DIFF_16_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _))
#define LIBXSMM_DIFF_48_ASSIGN(A, B) LIBXSMM_DIFF_16_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_48_, B, _))
#define LIBXSMM_DIFF_48_LOAD(A, SRC) LIBXSMM_DIFF_16_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(SRC) + 2)
#define LIBXSMM_DIFF_48(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_16(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_48_, A, _), (const uint64_t*)(B) + 2, __VA_ARGS__)))

/* 64 Byte ("SW"): two 32-Byte halves (second half at +4 uint64_t). */
#define LIBXSMM_DIFF_64SW_DECL(A) LIBXSMM_DIFF_32_DECL(A); LIBXSMM_DIFF_32_DECL(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _))
#define LIBXSMM_DIFF_64SW_ASSIGN(A, B) LIBXSMM_DIFF_32_ASSIGN(A, B); LIBXSMM_DIFF_32_ASSIGN(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), LIBXSMM_CONCATENATE3(libxsmm_diff_64_, B, _))
#define LIBXSMM_DIFF_64SW_LOAD(A, SRC) LIBXSMM_DIFF_32_LOAD(A, SRC); LIBXSMM_DIFF_32_LOAD(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(SRC) + 4)
#define LIBXSMM_DIFF_64SW(A, B, ...) ((unsigned char)(0 != LIBXSMM_DIFF_32(A, B, __VA_ARGS__) ? 1 : LIBXSMM_DIFF_32(LIBXSMM_CONCATENATE3(libxsmm_diff_64_, A, _), (const uint64_t*)(B) + 4, __VA_ARGS__)))

/* 64 Byte using AVX-512 (only if enabled above), otherwise an alias of the 64SW variant. */
#if defined(LIBXSMM_DIFF_AVX512_ENABLED)
# define LIBXSMM_DIFF_AVX512_DECL(A) __m512i A = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32()
# define LIBXSMM_DIFF_AVX512_ASSIGN(A, B) (A) = (B)
# define LIBXSMM_DIFF_AVX512_LOAD(A, SRC) A = _mm512_loadu_si512((const __m512i*)(SRC))
# define LIBXSMM_DIFF_AVX512(A, B, ...) ((unsigned char)(0xFFFF != (unsigned int)/*_cvtmask16_u32*/(_mm512_cmpeq_epi32_mask( \
    A, _mm512_loadu_si512((const __m512i*)(B))))))
#else
# define LIBXSMM_DIFF_AVX512_DECL LIBXSMM_DIFF_64SW_DECL
# define LIBXSMM_DIFF_AVX512_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN
# define LIBXSMM_DIFF_AVX512_LOAD LIBXSMM_DIFF_64SW_LOAD
# define LIBXSMM_DIFF_AVX512 LIBXSMM_DIFF_64SW
#endif

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_AVX512_DECL
# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_AVX512_ASSIGN
# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_AVX512_LOAD
# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_AVX512
#else
# define LIBXSMM_DIFF_64_DECL LIBXSMM_DIFF_64SW_DECL
# define LIBXSMM_DIFF_64_ASSIGN LIBXSMM_DIFF_64SW_ASSIGN
# define LIBXSMM_DIFF_64_LOAD LIBXSMM_DIFF_64SW_LOAD
# define LIBXSMM_DIFF_64 LIBXSMM_DIFF_64SW
#endif

/* Size-generic front-ends: N is pasted into the macro name (e.g., 4, 8, 16, 32, 48, 64). */
#define LIBXSMM_DIFF_DECL(N, A) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _DECL)(A)
#define LIBXSMM_DIFF_LOAD(N, A, SRC) LIBXSMM_CONCATENATE3(LIBXSMM_DIFF_, N, _LOAD)(A, SRC)
#define LIBXSMM_DIFF(N) LIBXSMM_CONCATENATE(LIBXSMM_DIFF_, N)
/**
 * Searches N elements (STRIDE bytes apart, starting at BN) for the first
 * element for which DIFF(A, element, ELEMSIZE) yields zero (i.e., no difference).
 * The search starts at index HINT and, if nothing was found up to N,
 * wraps around to scan the indexes [0, HINT).
 * RESULT receives the index of the match, or N if no element matches.
 * TYPE is the type of the temporary index used for the wrap-around scan.
 */
#define LIBXSMM_DIFF_N(TYPE, RESULT, DIFF, A, BN, ELEMSIZE, STRIDE, HINT, N) { \
  const char* libxsmm_diff_b_ = (const char*)(BN) + (size_t)(HINT) * (STRIDE); \
  for (RESULT = (HINT); (RESULT) < (N); ++(RESULT)) { \
    if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) break; \
    libxsmm_diff_b_ += (STRIDE); \
  } \
  if ((N) == (RESULT)) { /* wrong hint */ \
    TYPE libxsmm_diff_r_ = 0; \
    libxsmm_diff_b_ = (const char*)(BN); /* reset */ \
    for (; libxsmm_diff_r_ < (HINT); ++libxsmm_diff_r_) { \
      if (0 == DIFF(A, libxsmm_diff_b_, ELEMSIZE)) { \
        RESULT = libxsmm_diff_r_; \
        break; \
      } \
      libxsmm_diff_b_ += (STRIDE); \
    } \
  } \
}
/**
 * Function type representing the diff-functionality.
 * NOTE(review): this type returns unsigned int whereas the libxsmm_diff_<N>
 * functions declared below return unsigned char; confirm before casting
 * one of them to this type.
 */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE unsigned int (*libxsmm_diff_function)(
  const void* /*a*/, const void* /*b*/, ... /*size*/);

/** Compare two data blocks of 4 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_4(const void* a, const void* b, ...);
/** Compare two data blocks of 8 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_8(const void* a, const void* b, ...);
/** Compare two data blocks of 16 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_16(const void* a, const void* b, ...);
/** Compare two data blocks of 32 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_32(const void* a, const void* b, ...);
/** Compare two data blocks of 48 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_48(const void* a, const void* b, ...);
/** Compare two data blocks of 64 Byte each. */
LIBXSMM_API unsigned char libxsmm_diff_64(const void* a, const void* b, ...);

#endif /*LIBXSMM_DIFF_H*/
third_party/libxsmm/src/libxsmm_dnn.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include <libxsmm_dnn.h>
#include "libxsmm_main.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
/**
 * Initialization hook of the DNN domain. The current implementation
 * performs no work; target_arch is only marked as unused.
 */
LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch)
{
  LIBXSMM_UNUSED(target_arch);
}
/** Finalization hook of the DNN domain; the current implementation is empty. */
LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void)
{
}
LIBXSMM_API_INTERN
libxsmm_dnn_err_t
libxsmm_dnn_get_feature_map_blocks
(
int
C
,
int
K
,
int
*
C_block
,
int
*
K_block
,
int
*
fm_lp_block
,
libxsmm_dnn_datatype
datatype_in
,
libxsmm_dnn_datatype
datatype_out
)
{
libxsmm_dnn_err_t
status
=
LIBXSMM_DNN_SUCCESS
;
int
ifmblock
=
0
;
int
ofmblock
=
0
;
int
lp_block
=
0
;
int
tmp_max_c_block
=
64
;
int
tmp_max_k_block
=
64
;
int
tmp_block
=
0
;
/* init libxsmm */
LIBXSMM_INIT
/* C */
if
(
((
libxsmm_target_archid
>=
LIBXSMM_X86_AVX512_SPR
)
&&
(
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
))
||
(
libxsmm_target_archid
<
LIBXSMM_X86_AVX512
)
)
{
tmp_max_c_block
=
32
;
}
else
if
(
libxsmm_target_archid
==
LIBXSMM_AARCH64_V81
)
{
tmp_max_c_block
=
16
;
}
if
(
C
<
tmp_max_c_block
)
{
ifmblock
=
C
;
}
else
{
for
(
tmp_block
=
1
;
tmp_block
<=
tmp_max_c_block
;
tmp_block
*=
2
)
{
if
(
C
%
tmp_block
==
0
)
ifmblock
=
tmp_block
;
}
}
/* K */
if
(
((
libxsmm_target_archid
>=
LIBXSMM_X86_AVX512_SPR
)
&&
(
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
))
||
(
libxsmm_target_archid
<
LIBXSMM_X86_AVX512
)
)
{
tmp_max_k_block
=
32
;
}
else
if
(
libxsmm_target_archid
==
LIBXSMM_AARCH64_V81
)
{
tmp_max_k_block
=
16
;
}
if
(
K
<
tmp_max_k_block
)
{
ofmblock
=
K
;
}
else
{
for
(
tmp_block
=
1
;
tmp_block
<=
tmp_max_k_block
;
tmp_block
*=
2
)
{
if
(
K
%
tmp_block
==
0
)
ofmblock
=
tmp_block
;
}
}
/* when do we need VNNI format? */
if
(
(
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
{
lp_block
=
1
;
}
else
if
(
(
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
lp_block
=
2
;
}
else
if
(
(
datatype_in
==
LIBXSMM_DNN_DATATYPE_I16
)
&&
((
datatype_out
==
LIBXSMM_DNN_DATATYPE_I32
)
||
(
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
)
{
lp_block
=
2
;
}
else
if
(
datatype_in
==
LIBXSMM_DNN_DATATYPE_I8
)
{
lp_block
=
4
;
}
else
{
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
return
status
;
}
*
C_block
=
ifmblock
;
*
K_block
=
ofmblock
;
*
fm_lp_block
=
lp_block
;
return
status
;
}
/**
 * Translates a DNN status code into a human-readable message.
 *
 * @param code Status (success, warning, or error) to be described.
 * @return Pointer to a string literal; codes not listed below yield
 *         a generic "unknown error or warning" message.
 */
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code)
{
  const char* message;
  switch (code) {
    case LIBXSMM_DNN_SUCCESS:
      message = "LIBXSMM DNN Success!";
      break;
    case LIBXSMM_DNN_WARN_FALLBACK:
      message = "LIBXSMM DNN Warning: Falling back to naive code as target is currently not supported by LIBXSMM!";
      break;
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING:
      message = "LIBXSMM DNN Warning: RNN cell suboptimal minibatch blocking!";
      break;
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING:
      message = "LIBXSMM DNN Warning: RNN cell suboptimal input feature blocking!";
      break;
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING:
      message = "LIBXSMM DNN Warning: RNN cell suboptimal output feature blocking!";
      break;
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING:
      message = "LIBXSMM DNN Warning: FC layer suboptimal minibatch blocking!";
      break;
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING:
      message = "LIBXSMM DNN Warning: FC layer suboptimal input feature blocking!";
      break;
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING:
      message = "LIBXSMM DNN Warning: FC layer suboptimal output feature blocking!";
      break;
    case LIBXSMM_DNN_ERR_GENERAL:
      message = "LIBXSMM DNN Error: General error occurred!";
      break;
    case LIBXSMM_DNN_ERR_CREATE_HANDLE:
      message = "LIBXSMM DNN Error: Handle creation failed!";
      break;
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE:
      message = "LIBXSMM DNN Error: Requested datatype is not available!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_BLOCKING:
      message = "LIBXSMM DNN Error: Requested Input/Output buffer size cannot be blocked!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_HANDLE:
      message = "LIBXSMM DNN Error: An invalid handle was provided!";
      break;
    case LIBXSMM_DNN_ERR_DATA_NOT_BOUND:
      message = "LIBXSMM DNN Error: Not all required sources and destinations have been bound to convolution!";
      break;
    case LIBXSMM_DNN_ERR_CREATE_TENSOR:
      message = "LIBXSMM DNN Error: Tensor creation failed!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_TENSOR:
      message = "LIBXSMM DNN Error: Invalid tensor was specified!";
      break;
    case LIBXSMM_DNN_ERR_MISMATCH_TENSOR:
      message = "LIBXSMM DNN Error: Tensor doesn't match handle it should be bind to!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR:
      message = "LIBXSMM DNN Error: Invalid handle or tensor!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_KIND:
      message = "LIBXSMM DNN Error: Invalid convolution kind!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW:
      message = "LIBXSMM DNN Error: NCHW format is currently not natively supported by LIBXSMM!";
      break;
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT:
      message = "LIBXSMM DNN Error: Unsupported destination format when copying data!";
      break;
    case LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT:
      message = "LIBXSMM DNN Error: Unsupported source format when copying data!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE:
      message = "LIBXSMM DNN Error: Unsupported format when requesting a convolution!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS:
      message = "LIBXSMM DNN Error: KCRS format is currently not natively supported by LIBXSMM!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL:
      message = "LIBXSMM DNN Error: Invalid format was specified!";
      break;
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT:
      message = "LIBXSMM DNN Error: Layout creation failed!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_LAYOUT:
      message = "LIBXSMM DNN Error: Invalid layout was specified!";
      break;
    case LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH:
      message = "LIBXSMM DNN Error: Unsupported architecture!";
      break;
    case LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED:
      message = "LIBXSMM DNN Error: scratch binding failed as scratch was not allocated!";
      break;
    case LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE:
      message = "LIBXSMM DNN Error: an unknown tensor type was provided!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_ALGO:
      message = "LIBXSMM DNN Error: Invalid algorithm was specified!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_PADDING:
      message = "LIBXSMM DNN Error: Invalid padding was specified!";
      break;
    case LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL:
      message = "LIBXSMM DNN Error: time steps should be >= 2 for RNN/LSTM!";
      break;
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS:
      message = "LIBXSMM DNN Error: failed to create internal layout arrays!";
      break;
    case LIBXSMM_DNN_ERR_NOT_IMPLEMENTED:
      message = "LIBXSMM DNN Error: the requested functionality is right now not implemented!";
      break;
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER:
      message = "LIBXSMM DNN Error: the requested order of fusion in batch norm is right now not implemented!";
      break;
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION:
      message = "LIBXSMM DNN Error: the requested fusion in batch norm is right now not implemented!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN:
      message = "LIBXSMM DNN Error: Unsupported format when requesting a fused batch norm!";
      break;
    case LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING:
      message = "LIBXSMM DNN Error: Unsupported pooling operations was requested!";
      break;
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FC:
      message = "LIBXSMM DNN Error: Unsupported format when requesting a fullyconnected layer!";
      break;
    case LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN:
      message = "LIBXSMM DNN Error: max sequence length is shorter than sequence length we attempt to set!";
      break;
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER:
      message = "LIBXSMM DNN Error: the requested order of fusion in group norm is right now not implemented!";
      break;
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION:
      message = "LIBXSMM DNN Error: the requested fusion in group norm is right now not implemented!";
      break;
    case LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION:
      message = "LIBXSMM DNN Error: the requested fusion in fullyconnected is right now not implemented!";
      break;
    default:
      message = "LIBXSMM DNN Error: Unknown error or warning occurred!";
  }
  return message;
}
/**
 * Size of a single element of the given datatype.
 *
 * @param datatype Datatype to be queried.
 * @return Size in Bytes (4, 2, or 1); values not listed below yield 1.
 */
LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)
{
  size_t nbytes;
  switch (datatype) {
    case LIBXSMM_DNN_DATATYPE_F32:
    case LIBXSMM_DNN_DATATYPE_I32:
      nbytes = 4;
      break;
    case LIBXSMM_DNN_DATATYPE_BF16:
    case LIBXSMM_DNN_DATATYPE_I16:
      nbytes = 2;
      break;
    case LIBXSMM_DNN_DATATYPE_I8:
    /* no error expected as enumeration really arrives at an enum; compiler-checked */
    default:
      nbytes = 1;
  }
  return nbytes;
}
/**
 * Number of elements of the given datatype that fit into one vector
 * register of the current target.
 *
 * @param datatype Datatype of the elements.
 * @return Vector width in Bytes (16 for generic/SSE targets, 32 for AVX and
 *         AVX2, 64 for every other target) divided by the element size.
 */
LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)
{
  size_t vector_bytes = 64; /* applies unless one of the targets below matches */

  /* init libxsmm */
  LIBXSMM_INIT

  if (LIBXSMM_X86_GENERIC == libxsmm_target_archid
    || LIBXSMM_X86_SSE3 == libxsmm_target_archid
    || LIBXSMM_X86_SSE42 == libxsmm_target_archid)
  {
    vector_bytes = 16;
  }
  else if (LIBXSMM_X86_AVX2 == libxsmm_target_archid
    || LIBXSMM_X86_AVX == libxsmm_target_archid)
  {
    vector_bytes = 32;
  }

  return vector_bytes / libxsmm_dnn_typesize(datatype);
}
/**
 * Largest absolute value found in a buffer.
 * With OpenMP, each thread reduces its share of the iterations into a
 * private maximum, and the private maxima are merged in a critical section.
 *
 * @param in_buffer Values to be scanned; the first element is read
 *                  unconditionally (regardless of length).
 * @param length Number of values.
 * @return Maximum of the absolute values.
 */
LIBXSMM_API_INLINE float libxsmm_internal_get_max(float* in_buffer, int length)
{
  float global_max = LIBXSMM_ABS(in_buffer[0]);
  int idx = 0;
#ifdef _OPENMP
  LIBXSMM_OMP_VAR(idx);
# pragma omp parallel private(idx)
  {
    float thread_max = global_max;
#   pragma omp for
    for (idx = 0; idx < length; ++idx) {
      const float magnitude = LIBXSMM_ABS(in_buffer[idx]);
      if (thread_max < magnitude) {
        thread_max = magnitude;
      }
    }
#   pragma omp critical
    {
      if (global_max < thread_max) {
        global_max = thread_max;
      }
    }
  }
#else
  /* the first element is already accounted for */
  for (idx = 1; idx < length; ++idx) {
    const float magnitude = LIBXSMM_ABS(in_buffer[idx]);
    if (global_max < magnitude) {
      global_max = magnitude;
    }
  }
#endif
  return global_max;
}
/**
 * Returns the raw exponent field of the largest absolute value in
 * in_buffer[0..length-1]: the value's bit pattern is masked with
 * LIBXSMM_DNN_MASK_ABS_F32 and shifted right by LIBXSMM_DNN_MANT_SZ_F32.
 */
LIBXSMM_API_INLINE unsigned char libxsmm_internal_get_max_exp( float* in_buffer, int length ) {
  libxsmm_intfloat absmax;

  /* the union gives bit-wise access to the float returned by the max search */
  absmax.f = libxsmm_internal_get_max( in_buffer, length );

  /* drop the sign, shift the mantissa out, and narrow to a char */
  return (unsigned char)((absmax.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32);
}
LIBXSMM_API_INLINE
short
libxsmm_internal_quantize_scalar_no_scf
(
float
input
,
unsigned
char
max_exp
,
unsigned
char
add_shift
,
int
round_mode
)
{
libxsmm_intfloat
value
;
unsigned
int
qvalue
=
0
;
unsigned
int
mant
=
0
;
unsigned
int
sign
=
0
;
unsigned
char
rhs
=
0
;
unsigned
char
exp_off
=
0
;
/* init libxsmm */
LIBXSMM_INIT
/* in case of zero we don't need to do anything */
if
(
LIBXSMM_FEQ
(
input
,
0
))
{
qvalue
=
0
;
}
else
{
/* let's get a float copy to work on */
/* vinp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( in_buffer[i] ); */
value
.
f
=
input
;
/* let's compute the offset of the current exp at pos i from max offset, we need to mask the sign bit though */
/*__m512i vexp = _mm512_cvtps_epi32(_mm512_getexp_ps (vinp));
__m512i vexp_off = _mm512_sub_epi32(maxexpf, vexp);*/
exp_off
=
(
unsigned
char
)(
max_exp
-
((
value
.
ui
&
LIBXSMM_DNN_MASK_ABS_F32
)
>>
LIBXSMM_DNN_MANT_SZ_F32
));
/* cut out mantissa and set leading bit */
/*__m512i mmask = _mm512_set1_epi32(LIBXSMM_DNN_MASK_MANT_F32);
__m512i vmant = _mm512_or_epi32(_mm512_set1_epi32(0x1 << LIBXSMM_DNN_MANT_SZ_F32), _mm512_and_epi32( _mm512_castps_si512( vinp ), mmask));*/
mant
=
((
0x1
<<
LIBXSMM_DNN_MANT_SZ_F32
)
|
(
value
.
ui
&
LIBXSMM_DNN_MASK_MANT_F32
));
/* extract sign */
/* __mmask16 smask = _mm512_cmplt_ps_mask (inp, _mm512_set1_ps(0)); */
sign
=
((
value
.
ui
&
LIBXSNN_DNN_MASK_SIGN_F32
)
>>
(
LIBXSMM_DNN_SZ_F32
-
1
));
/* calculate rhs, be aware of the now explicit leading bit, @TODO add DFP8/4 */
rhs
=
(
unsigned
char
)((
LIBXSMM_DNN_MANT_SZ_F32
+
1
)
-
LIBXSMM_DNN_MANT_DFP16
+
exp_off
+
add_shift
);
/* some safety, to generate 0 when we fall off quant region, @TODO issue a LIBXSMM WARNING: that we shifted out the entire mantissa */
if
(
rhs
>
(
LIBXSMM_DNN_MANT_SZ_F32
+
1
))
{
rhs
=
(
LIBXSMM_DNN_MANT_SZ_F32
+
1
);
}
/* finally shift the value into the region we need, this is now a 15-add_rhs bit number for the max value in in_buffer */
qvalue
=
(
mant
>>
rhs
);
/* handle sign, 2 complement */
if
(
(
sign
>
0
)
&&
(
qvalue
>
0
)
)
{
qvalue
=
(
~
qvalue
+
1
);
}
if
(
round_mode
==
LIBXSMM_DNN_QUANT_BIAS_ROUND
)
{
/* biased rounding towards next bigger number */
/* first let's determine in the original number if we need a bias rounding, @TODO need fix for F64 */
int
bias_needed
=
(
mant
&
(
0x3
<<
(
rhs
-
2
)));
/* apply bias */
if
(
bias_needed
>
0
)
{
qvalue
++
;
}
}
else
if
(
round_mode
==
LIBXSMM_DNN_QUANT_NEAREST_ROUND
)
{
int
nearest_needed
=
(
mant
&
(
0x1
<<
(
rhs
-
1
)));
/* apply rounding */
if
((
nearest_needed
>
0
)
&&
(
rhs
>
1
))
{
qvalue
++
;
}
}
else
if
(
round_mode
==
LIBXSMM_DNN_QUANT_STOCH_ROUND
)
{
/* stochastic rounding, as implemented in the IBM paper from 2015, @TODO, fix F64 and DFP8 */
const
float
eps
=
LIXSMMM_DNN_RES_DFP16
;
/* coverity[dont_call] */
const
float
r
=
(
float
)
rand
();
libxsmm_intfloat
fvalue
;
float
p
,
q
;
/* masking all bits which will be shifted out */
fvalue
.
ui
=
value
.
ui
&
((
LIBXSMM_DNN_MASK_FULL_F32
)
<<
rhs
);
/* drawing a random number */
p
=
r
/
((
float
)
RAND_MAX
);
q
=
(
input
-
fvalue
.
f
)
/
eps
;
/* apply rounding if needed */
if
((
p
+
q
)
>
0
.
5
f
)
{
++
qvalue
;
}
}
else
{
/* do nothing about rounding, just chop */
}
}
return
(
short
)
qvalue
;
}
/* @TODO make this routine aware of any int type */
/**
 * Quantizes `length` floats from in_buffer into 16-bit integers in out_buffer
 * using a single scaling exponent for the whole buffer, which is written to *scf.
 *
 * With LIBXSMM_DNN_QUANT_FPHW_ROUND every element is multiplied by a scale
 * derived from the buffer's largest magnitude and rounded in floating point;
 * every other round_mode goes through libxsmm_internal_quantize_scalar_no_scf.
 *
 * @param in_buffer   input values
 * @param out_buffer  receives the quantized values
 * @param length      number of elements
 * @param add_shift   additional shift, reduces the bits used for the largest value
 * @param scf         out: scaling exponent
 * @param round_mode  one of the LIBXSMM_DNN_QUANT_*_ROUND modes
 */
LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  int idx = 0;

  /* init libxsmm */
  LIBXSMM_INIT

  if ( round_mode != LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    /* integer path: quantize every element against the largest exponent */
    unsigned char peak_exp = libxsmm_internal_get_max_exp( in_buffer, length );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(idx)
#endif
    for (idx = 0; idx < length; ++idx ) {
      out_buffer[idx] = libxsmm_internal_quantize_scalar_no_scf( in_buffer[idx], peak_exp, add_shift, round_mode );
    }

    *scf = (unsigned char)(14 - add_shift - (peak_exp - 127));
  } else {
    /* in case we are using FP-Mul based quantization we use a different path for now
       @TODO let's unify the paths by using the similar vectorization for both */
    const float absmax = libxsmm_internal_get_max( in_buffer, length );
    int shift_exp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scale = LIBXSMM_FREXPF(absmax, &shift_exp);
    shift_exp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scale = libxsmm_sexp2_i8i(-shift_exp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( length % 16 == 0 ) {
      /* vector path: 16 floats per iteration, written with a streaming store */
      __m512 vscale = _mm512_set1_ps( scale );
#ifdef _OPENMP
#     pragma omp parallel for private(idx)
#endif
      for (idx = 0; idx < length; idx += 16 ) {
        _mm256_stream_si256( (__m256i *)&(out_buffer[idx]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[idx]), vscale ) );
      }
    } else {
#endif
#ifdef _OPENMP
#     pragma omp parallel for private(idx)
#endif
      for (idx = 0; idx < length; ++idx ) {
        out_buffer[idx] = (short)LIBXSMM_ROUNDF(in_buffer[idx] * scale);
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (shift_exp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-shift_exp);
  }
}
/**
 * Quantizes a blocked float activation tensor into a blocked 16-bit tensor
 * using one scaling exponent for the whole tensor, which is written to *scf.
 *
 * The input is indexed as [N][C/cblk_f32][H][W][cblk_f32] and the output as
 * [N][C/(cblk_i16*lp_blk)][H][W][cblk_i16][lp_blk]; the channel of an output
 * element is (i2*cblk_i16*lp_blk + i5*lp_blk + i6).
 *
 * @param in_buffer   float input tensor
 * @param out_buffer  receives the quantized tensor
 * @param N,C,H,W     logical tensor dimensions
 * @param cblk_f32    channel block size of the input; must divide C
 * @param cblk_i16    channel block size of the output
 * @param lp_blk      low-precision block size of the output;
 *                    cblk_i16*lp_blk must divide C
 * @param add_shift   additional shift, reduces the bits used for the largest value
 * @param scf         out: scaling exponent
 * @param round_mode  one of the LIBXSMM_DNN_QUANT_*_ROUND modes
 */
LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  LIBXSMM_VLA_DECL(5, const float, in,  in_buffer,  C/cblk_f32, H, W, cblk_f32);
  LIBXSMM_VLA_DECL(6, short, out, out_buffer, C/(cblk_i16*lp_blk), H, W, cblk_i16, lp_blk);
  const unsigned int cblk = C/(cblk_i16*lp_blk);
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6;

  /* init libxsmm */
  LIBXSMM_INIT

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  /* the output iterates over C/(cblk_i16*lp_blk) channel blocks, hence the full
     product has to divide C (otherwise trailing channels would be skipped) */
  assert((C % (cblk_i16*lp_blk)) == 0);

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, N*C*H*W );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( (cblk_f32 == 16) && (cblk_i16*lp_blk == 16) ) {
      /* both blockings are 16 wide: the tensor is processed as one flat array */
      __m512 vscfq = _mm512_set1_ps( scfq );
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1);
#     pragma omp parallel for private(i1)
#endif
      for (i1 = 0; i1 < (int)(N*C*H*W); i1 += 16 ) {
        _mm256_stream_si256( (__m256i *)&(out_buffer[i1]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i1]), vscfq ) );
      }
    } else {
#endif
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)N; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)H; ++i3 ) {
            for (i4 = 0; i4 < (int)W; ++i4 ) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
                for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
                  /* translate the output position into the input blocking */
                  const int fi1 = i1;
                  const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
                  LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                    LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32) * scfq);
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant act\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, N*C*H*W );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)N; ++i1 ) {
      for (i2 = 0; i2 < (int)cblk; ++i2 ) {
        for (i3 = 0; i3 < (int)H; ++i3 ) {
          for (i4 = 0; i4 < (int)W; ++i4 ) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
              for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
                /* translate the output position into the input blocking */
                const int fi1 = i1;
                const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
                const int fi3 = i3;
                const int fi4 = i4;
                const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
                LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
                  LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32), max_exp, add_shift, round_mode);
              }
            }
          }
        }
      }
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
/**
 * Quantizes a blocked float filter tensor into a blocked 16-bit tensor using
 * one scaling exponent for the whole tensor, which is written to *scf.
 *
 * The input is indexed as [K/kblk_f32][C/cblk_f32][R][S][cblk_f32][kblk_f32]
 * and the output as [K/kblk_i16][C/(cblk_i16*lp_blk)][R][S][cblk_i16][kblk_i16][lp_blk].
 *
 * @param in_buffer   float input tensor
 * @param out_buffer  receives the quantized tensor
 * @param K,C,R,S     logical tensor dimensions
 * @param cblk_f32    input channel block size of the input; must divide C
 * @param cblk_i16    input channel block size of the output
 * @param kblk_f32    output channel block size of the input; must divide K
 * @param kblk_i16    output channel block size of the output; must divide K
 * @param lp_blk      low-precision block size, must be even;
 *                    cblk_i16*lp_blk must divide C
 * @param add_shift   additional shift, reduces the bits used for the largest value
 * @param scf         out: scaling exponent
 * @param round_mode  one of the LIBXSMM_DNN_QUANT_*_ROUND modes
 */
LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  LIBXSMM_VLA_DECL(6, const float, in,  in_buffer,  C/cblk_f32, R, S, cblk_f32, kblk_f32);
  LIBXSMM_VLA_DECL(7, short, out, out_buffer, C/(cblk_i16*lp_blk), R, S, cblk_i16, kblk_i16, lp_blk);
  unsigned int cblk = C/(cblk_i16*lp_blk);
  unsigned int kblk = K/kblk_i16;
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6, i7;

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % (cblk_i16*lp_blk)) == 0);
  assert((K % kblk_f32) == 0);
  assert((K % kblk_i16) == 0);
  assert((lp_blk % 2) == 0);

  /* init libxsmm */
  LIBXSMM_INIT

  if ( round_mode != LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    /* integer path: quantize every element against the largest exponent */
    unsigned char peak_exp = libxsmm_internal_get_max_exp( in_buffer, K*C*R*S );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)kblk; ++i1 ) {
      for (i2 = 0; i2 < (int)cblk; ++i2 ) {
        for (i3 = 0; i3 < (int)R; ++i3 ) {
          for (i4 = 0; i4 < (int)S; ++i4 ) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
              for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
                for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
                  /* logical output/input channel of this element */
                  const unsigned int ofm = (i1*kblk_i16)+i6;
                  const unsigned int ifm = (i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7;
                  /* position of the same element in the input blocking */
                  const int fi1 = ofm/kblk_f32;
                  const int fi2 = ifm/cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ifm%cblk_f32;
                  const int fi6 = ofm%kblk_f32;
                  LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
                    LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32), peak_exp, add_shift, round_mode);
                }
              }
            }
          }
        }
      }
    }

    *scf = (unsigned char)(14 - add_shift - (peak_exp - 127));
  } else {
    /* in case we are using FP-Mul based quantization we use a different path for now
       @TODO let's unify the paths by using the similar vectorization for both */
    const float absmax = libxsmm_internal_get_max( in_buffer, K*C*R*S );
    int shift_exp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scale = LIBXSMM_FREXPF(absmax, &shift_exp);
    shift_exp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scale = libxsmm_sexp2_i8i(-shift_exp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( (kblk_f32 == 16) && (cblk_f32 == 16) && (kblk_i16 == 16) && (cblk_i16*lp_blk == 16) ) {
      const __m512 vscale = _mm512_set1_ps( scale );
      const __m512i permute_compact_idx = _mm512_set_epi32(15,14,13,12,7,6,5,4,11,10,9,8,3,2,1,0);
#ifdef _OPENMP
#     pragma omp parallel for private(i1, i2, i3, i4, i5) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)R; ++i3 ) {
            for (i4 = 0; i4 < (int)S; ++i4 ) {
              /* two consecutive input-channel rows are quantized and interleaved per step */
              for (i5 = 0; i5 < 16; i5 += 2 ) {
                __m256i row_even = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 0, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscale);
                __m256i row_odd = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 1, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscale);
                __m256i pairs_lo = _mm256_unpacklo_epi16(row_even, row_odd);
                __m256i pairs_hi = _mm256_unpackhi_epi16(row_even, row_odd);
                /* glue both halves into one 512-bit register and reorder its 32-bit lanes */
                __m512i packed = _mm512_inserti64x4(
                  _mm512_inserti64x4(_mm512_setzero_si512(), pairs_lo, 0), pairs_hi, 1);
                packed = _mm512_permutexvar_epi32(permute_compact_idx, packed);
                LIBXSMM_INTRINSICS_MM512_STREAM_SI512(
                  (void*)&LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5 / 2, 0, 0, cblk, R, S, cblk_i16, kblk_i16, lp_blk),
                  packed);
              }
            }
          }
        }
      }
    } else {
#endif
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); LIBXSMM_OMP_VAR(i7);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)R; ++i3 ) {
            for (i4 = 0; i4 < (int)S; ++i4 ) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
                for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
                  for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
                    /* logical output/input channel of this element */
                    const unsigned int ofm = (i1*kblk_i16)+i6;
                    const unsigned int ifm = (i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7;
                    /* position of the same element in the input blocking */
                    const int fi1 = ofm/kblk_f32;
                    const int fi2 = ifm/cblk_f32;
                    const int fi3 = i3;
                    const int fi4 = i4;
                    const int fi5 = ifm%cblk_f32;
                    const int fi6 = ofm%kblk_f32;
                    LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                      LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32) * scale);
                  }
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (shift_exp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-shift_exp);
  }
}
/**
 * Converts `length` 16-bit integers from in_buffer back to floats in out_buffer
 * by multiplying each value with libxsmm_sexp2_i8i(-scf).
 *
 * @param in_buffer   quantized input values
 * @param out_buffer  receives the float values
 * @param length      number of elements
 * @param scf         scaling exponent (presumably the one produced during
 *                    quantization; confirm against callers)
 */
LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ) {
  /* the factor is identical for all elements, compute it once */
  const float scale = libxsmm_sexp2_i8i(-scf);
  int idx = 0;

#ifdef _OPENMP
# pragma omp parallel for private(idx)
#endif
  for ( idx = 0; idx < length; ++idx ) {
    out_buffer[idx] = ((float)in_buffer[idx])*scale;
  }
}
/**
 * Converts `length` floats to bfloat16 by truncation: each float is stored
 * into a libxsmm_bfloat16_hp union and element i[1] of that union is kept
 * (presumably the upper 16 bits on little-endian targets; no rounding is applied).
 */
LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length) {
  libxsmm_bfloat16_hp bits;
  unsigned int idx;

  for ( idx = 0; idx < length; ++idx ) {
    bits.f = in[idx];
    out[idx] = bits.i[1];
  }
}
/**
 * Converts `len` floats to bfloat16 using round-to-nearest, ties away from zero.
 *
 * Values whose exponent field is all ones (NaN and infinity) are not rounded,
 * they are only shortened to their upper 16 bits.
 *
 * @param in   input values
 * @param out  receives the bfloat16 values
 * @param len  number of elements
 */
LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
  unsigned int i = 0;

  for ( i = 0; i < len; ++i ) {
    /* a union gives access to the bit pattern without casting the (const) float
       pointer to an incompatible pointer type */
    union { float f; unsigned int u; } bits;
    bits.f = in[i];

    /* we don't round NaN and inf */
    if ( (bits.u & 0x7f800000) != 0x7f800000 ) {
      /* adding half an ULP of the result rounds to nearest, ties away from zero */
      bits.u = bits.u + 0x00008000;
    }

    /* create the bf16 value by shifting out the lower 16bits */
    out[i] = (libxsmm_bfloat16)(bits.u >> 16);
  }
}
/**
 * Converts `len` floats to bfloat16 using round-to-nearest, ties to even.
 *
 * Values whose exponent field is all ones (NaN and infinity) are not rounded,
 * they are only shortened to their upper 16 bits.
 *
 * @param in   input values
 * @param out  receives the bfloat16 values
 * @param len  number of elements
 */
LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
  unsigned int i = 0;

  for ( i = 0; i < len; ++i ) {
    /* a union gives access to the bit pattern without casting the (const) float
       pointer to an incompatible pointer type */
    union { float f; unsigned int u; } bits;
    bits.f = in[i];

    /* we don't round NaN and inf */
    if ( (bits.u & 0x7f800000) != 0x7f800000 ) {
      /* the lowest bit that survives decides which way an exact tie is resolved */
      const unsigned int fixup = (bits.u >> 16) & 1;
      bits.u = bits.u + 0x00007fff + fixup;
    }

    /* create the bf16 value by shifting out the lower 16bits */
    out[i] = (libxsmm_bfloat16)(bits.u >> 16);
  }
}
/**
 * Converts `length` bfloat16 values to float: each input is placed into
 * element i[1] of a libxsmm_bfloat16_hp union, element i[0] is cleared, and
 * the union is read back as float.
 */
LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length) {
  libxsmm_bfloat16_hp bits;
  unsigned int idx;

  /* up-convert is super simple */
  for ( idx = 0; idx < length; ++idx ) {
    bits.i[0] = 0;
    bits.i[1] = in[idx];
    out[idx] = bits.f;
  }
}
third_party/libxsmm/src/libxsmm_dnn_convolution.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst, Alexander Heinecke, Evangelos Georganas, Rajkishore Barik (Intel Corp.)
******************************************************************************/
#include <libxsmm_sync.h>
#include "libxsmm_main.h"
#include "libxsmm_dnn_convolution_forward.h"
#include "libxsmm_dnn_convolution_backward.h"
#include "libxsmm_dnn_convolution_weight_update.h"
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <math.h>
#if defined(_OPENMP)
# include <omp.h>
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif
#define MIXED 0
#define KHWC 1
#define HWKC 2
#define CHWK 3
#define HWCK 4
/**********************************************************/
/* Helper functions for convolutions' general param setup */
/**********************************************************/
/* Returns the input feature map block size reported by
 * libxsmm_dnn_get_feature_map_blocks for the layer's C, K and datatypes
 * (1 if the helper leaves the value untouched). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ifmblock( libxsmm_dnn_layer* handle ) {
  int ifmblock = 1;
  int ofmblock_unused, lpblock_unused; /* required by the helper, not needed here */

  libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
                                      &ifmblock, &ofmblock_unused, &lpblock_unused,
                                      handle->desc.datatype_in, handle->desc.datatype_out );

  return ifmblock;
}
/* Returns the output feature map block size reported by
 * libxsmm_dnn_get_feature_map_blocks for the layer's C, K and datatypes
 * (1 if the helper leaves the value untouched). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_ofmblock( libxsmm_dnn_layer* handle ) {
  int ofmblock = 1;
  int ifmblock_unused, lpblock_unused; /* required by the helper, not needed here */

  libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
                                      &ifmblock_unused, &ofmblock, &lpblock_unused,
                                      handle->desc.datatype_in, handle->desc.datatype_out );

  return ofmblock;
}
/* Returns the low-precision block size reported by
 * libxsmm_dnn_get_feature_map_blocks for the layer's C, K and datatypes
 * (1 if the helper leaves the value untouched). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fm_lp_block( libxsmm_dnn_layer* handle ) {
  int lpblock = 1;
  int ifmblock_unused, ofmblock_unused; /* required by the helper, not needed here */

  libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
                                      &ifmblock_unused, &ofmblock_unused, &lpblock,
                                      handle->desc.datatype_in, handle->desc.datatype_out );

  return lpblock;
}
/* Returns 1 if the forward pass has to use the fallback loops, 0 otherwise.
 * FIXME: For now fallback only if MB is not divisible by number of threads */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_fwd( libxsmm_dnn_layer* handle ) {
  return (handle->desc.N % handle->desc.threads != 0) ? 1 : 0;
}
/* Returns the number of input feature map blocks: C divided by handle->ifmblock
 * (integer division). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm( libxsmm_dnn_layer* handle ) {
  return handle->desc.C / handle->ifmblock;
}
/* Returns the number of output feature map blocks: K divided by handle->ofmblock
 * (integer division). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm( libxsmm_dnn_layer* handle ) {
  return handle->desc.K / handle->ofmblock;
}
/**********************************************************/
/* Helper functions for FWD convolutions' parameter setup */
/**********************************************************/
/* Picks the forward register blocking along the output width.
 * Default is the full output width; a width of 56 is split into 28, and for
 * int8 input any even width is halved (this takes precedence over the 56 rule). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofw_rb( libxsmm_dnn_layer* handle ) {
  int ofw_rb = (handle->ofw == 56) ? 28 : handle->ofw;

  if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) && (handle->ofw % 2 == 0) ) {
    ofw_rb = handle->ofw / 2;
  }

  return ofw_rb;
}
/* Decides whether the forward pass packs its input (1) or not (0).
 * Packing is only considered for 1x1 convolutions with stride 2 on output
 * widths of at most 14, and is always disabled for int8 input or when the
 * minibatch differs from the number of threads. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_fwd( libxsmm_dnn_layer* handle ) {
  const int small_image = (handle->ofw <= 14);
  const int strided_1x1 = (handle->desc.R == 1) && (handle->desc.S == 1) &&
                          (handle->desc.u == 2) && (handle->desc.v == 2);
  int pack = 0;

  /* Pack only for small images and when having large K to amortize, and we can only pack for 1x1 convolutions */
  if ( small_image && (handle->desc.K > 512) && strided_1x1 ) {
    pack = 1;
  }

  /* For SPR we allow packing more aggressively to generate more efficient BRGEMMs */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) )
  {
    if ( small_image && strided_1x1 ) {
      pack = 1;
    }
  }

  /* Never pack unless the minibatch equals the number of threads, since H is
     potentially used for parallelism.
     NOTE(review): the test is "N != threads", not a divisibility check - confirm intent */
  if ( handle->desc.N != handle->desc.threads ) {
    pack = 0;
  }

  /* we don't pack for int8 */
  if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) {
    pack = 0;
  }

  return pack;
}
/* Picks the forward register blocking along the output height (number of
 * output rows handled together). The rules below are applied in order, later
 * rules override earlier ones. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_ofh_rb( libxsmm_dnn_layer* handle ) {
  const int is_int8 = (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8);
  const int is_bf16 = (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16);
  int ofh_rb = 1;

  /* Multiple rows for "small" images and 1x1 convolutions */
  if ( (handle->ofh <= 14) && (handle->desc.R == 1) && (handle->desc.S == 1) ) {
    ofh_rb = handle->ofh;
  }

  /* In this case we will be using fallback generic loops, thus ofh_rb should be 1 */
  if ( (handle->desc.N % handle->desc.threads != 0) || is_int8 ) {
    ofh_rb = 1;
  }

  /* SPR with low-precision input: fixed choices for 7x7 outputs with 3x3
     filters and for 14x14 outputs */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       (is_bf16 || is_int8) )
  {
    if ( handle->ofw == 7 && handle->ofh == 7 && handle->desc.R == 3 && handle->desc.S == 3 ) {
      ofh_rb = 7;
    }
    if ( handle->ofw == 14 && handle->ofh == 14 /*&& handle->desc.R == 3 && handle->desc.S == 3*/ ) {
      ofh_rb = 2;
    }
  }

  /* Make sure we don't use multiple rows when we don't pack input and convolutions are strided */
  if ( (handle->pack_input == 0) && ((handle->desc.u != 1) || (handle->desc.v != 1)) ) {
    ofh_rb = 1;
  }

  return ofh_rb;
}
/* Returns the number of output pixels covered by one forward GEMM.
 * By default this is fwd_ofw_rb * fwd_ofh_rb. On SPR with bf16/int8 input,
 * non-1x1 filters and an output width below 24, the rows are padded by pad_w
 * so that pixels are computed redundantly in order to efficiently use AMX. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_pixels_gemm( libxsmm_dnn_layer* handle ) {
  int result = handle->fwd_ofw_rb * handle->fwd_ofh_rb;

  /* In the case below we calculate redundantly pixels in order to efficiently use AMX */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) )
  {
    /* any filter that is not 1x1: both filter dimensions are checked (the check
       used to test R twice, which ignored filters with R == 1 and S != 1) */
    if ( handle->desc.R != 1 || handle->desc.S != 1 ) {
      if ( handle->ofw < 24 ) {
        /* inner rows span the width padded on both sides; the first and the
           last row only carry the padding on one side */
        result = (handle->fwd_ofw_rb + 2 * handle->desc.pad_w) * (handle->fwd_ofh_rb - 2)
               + 2 * (handle->fwd_ofw_rb + handle->desc.pad_w);
      }
    }
  }

  return result;
}
/* Picks the forward blocking factor along the output height. Starts from 14,
 * switches to 4 for a set of tuned shapes, and is finally decremented until it
 * is a multiple of handle->fwd_ofh_rb. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_block_H( libxsmm_dnn_layer* handle ) {
  const int ofh = handle->ofh;
  int block_h = 14;

  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) )
  {
    /* Spatial dimension block tuning for SPR */
    if ( (ofh == 7 && handle->desc.u == 2) ||
         (ofh == 14 && handle->desc.R != 3) ||
         ofh == 27 ||
         (ofh == 28 && handle->desc.R == 1) ||
         ofh == 48 || ofh == 54 || ofh == 56 || ofh == 112 )
    {
      block_h = 4;
    }
  } else {
    /* Block H only for large images */
    if ( ofh >= 28 ) {
      block_h = 4;
    }
    /* ... except for 28 rows combined with R == 3 */
    if ( ofh == 28 && handle->desc.R == 3 ) {
      block_h = 14;
    }
  }

  /* Make sure it is divisible by the ofh_rb factor in the kernel */
  while ( block_h % handle->fwd_ofh_rb != 0 ) {
    block_h--;
  }

  return block_h;
}
/* Picks how many input feature map blocks are brought into the kernel at once.
 * The rules below are applied in order, later rules override earlier ones. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksifm_blocking( libxsmm_dnn_layer* handle ) {
  const int is_int8 = (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8);
  int blocking = 1;

  if ( (handle->desc.R == 1) && (handle->desc.S == 1) ) {
    /* For 1x1 Convolutions bring in kernel all IFMs unless filters are huge */
    blocking = handle->blocksifm;
    if ( (handle->desc.C >= 2048) && (handle->desc.K >= 512) ) {
      blocking = 1;
    }
    /* targets below AVX-512 use a fixed blocking for large C */
    if ( handle->target_archid < LIBXSMM_X86_AVX512 ) {
      if ( handle->desc.C >= 1024 ) {
        blocking = 4;
      } else if ( handle->desc.C >= 512 ) {
        blocking = 2;
      }
    }
  } else {
    /* If small image can bring in more IFMS even if NOT 1x1 convolution */
    blocking = (handle->ofw <= 7) ? 2 : 1;
  }

  /* the blocking has to divide the number of IFM blocks */
  if ( handle->blocksifm % blocking != 0 ) {
    blocking = 1;
  }

  /* In case of SPR bring always in all accumulation */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || is_int8) )
  {
    blocking = handle->blocksifm;
  }

  /* int8 input always brings in all IFM blocks */
  if ( is_int8 ) {
    blocking = handle->blocksifm;
  }

  return blocking;
}
/* Selects the forward loop order: 1 for the tuned shapes below, 0 otherwise. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_fwd( libxsmm_dnn_layer* handle ) {
  /* Switch to loop order 1 only if 1x1 convolution with "large" input image and "small" K */
  const int large_image_1x1 = (handle->desc.H >= 28) && (handle->desc.R == 1) && (handle->desc.S == 1) &&
                              (handle->desc.C >= 512) && (handle->desc.K <= 512);
  /* output width 56, R == 1, C == 256, K == 64 */
  const int ofw56_c256_k64 = (handle->ofw == 56) && (handle->desc.R == 1) &&
                             (handle->desc.C == 256) && (handle->desc.K == 64);
  /* output width 28, R == 1 */
  const int ofw28_r1 = (handle->ofw == 28) && (handle->desc.R == 1);

  return (large_image_1x1 || ofw56_c256_k64 || ofw28_r1) ? 1 : 0;
}
/* Picks the forward blocking over input feature map blocks: 8 by default
 * (4 for ofw == 7, C == 2048, K == 512), incremented to the next multiple of
 * handle->blocksifm_blocking and capped at handle->blocksifm. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_IFM( libxsmm_dnn_layer* handle ) {
  int block_ifm = (handle->ofw == 7 && handle->desc.C == 2048 && handle->desc.K == 512) ? 4 : 8;

  /* Make sure it is divisible by ifms in the kernel */
  for ( ; block_ifm % handle->blocksifm_blocking != 0; block_ifm++ ) {
    /* keep incrementing until the remainder vanishes */
  }

  block_ifm = LIBXSMM_MIN(handle->blocksifm, block_ifm);

  return block_ifm;
}
/* Picks the forward blocking over output feature map blocks: 8 by default,
 * 16 for ofw == 7 or for ofw == 14 with K == 1024, capped at handle->blocksofm. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_fwd_OFM( libxsmm_dnn_layer* handle ) {
  int block_ofm = 8;

  if ( (handle->ofw == 14 && handle->desc.K == 1024) || (handle->ofw == 7) ) {
    block_ofm = 16;
  }

  block_ofm = LIBXSMM_MIN(handle->blocksofm, block_ofm);

  return block_ofm;
}
/* Decides whether the forward pass parallelizes over output feature maps (1)
 * or not (0). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ofm_parallelization( libxsmm_dnn_layer* handle ) {
  int use_ofm_par = 0;

#if 0
  /* Use "hybrid" minibatch/ofm parallelization if we have huge filters */
  if ((handle->desc.R >= 3) && (handle->desc.S >= 3) && (handle->desc.C >= 512) && (handle->desc.K >= 512)) {
    result = 1;
  }
#endif

  /* small image with C == 1024 and K == 512 */
  if ( (handle->ofw <= 7) && (handle->desc.C == 1024) && (handle->desc.K == 512) ) {
    use_ofm_par = 1;
  }

  /* SPR with low-precision input: every layer with an output width of 7 */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) &&
       (handle->ofw == 7) )
  {
    use_ofm_par = 1;
  }

  return use_ofm_par;
}
/* Decides whether the forward pass skips the FMAs on the image rim (1) or not (0). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd( libxsmm_dnn_layer* handle ) {
  const int is_int8 = (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8);
  const int is_bf16 = (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16);
  /* 3x3 filter, unit stride, input padding of 1 and a square input image */
  const int padded_square_3x3 = (handle->desc.R == 3) && (handle->desc.S == 3) &&
                                (handle->desc.u == 1) && (handle->desc.v == 1) &&
                                (handle->desc.pad_h_in == 1) && (handle->desc.pad_w_in == 1) &&
                                (handle->desc.H == handle->desc.W);
  int avoid_rim = 0;

  /* Avoid rim FMA if the convolution is 3x3 (non-strided) and the image is "small";
     never for int8 input */
  if ( padded_square_3x3 && (handle->ofw <= 28) && !is_int8 ) {
    avoid_rim = 1;
  }

  /* disabled on SPR with low-precision input */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       (is_bf16 || is_int8) )
  {
    avoid_rim = 0;
  }

  return avoid_rim;
}
/* Decides whether filter accesses are shuffled (1) or not (0). */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_shuffle_filter_accesses( libxsmm_dnn_layer* handle ) {
  /* Shuffle filter accesses only if "pure minibatch" parallelization and large filters are involved */
  const int large_filters = (handle->use_ofm_parallelization == 0) &&
                            (handle->desc.C > 512) && (handle->desc.K > 512);
  /* tuned shapes, all with an output width of 7 */
  const int ofw7_r3_c512 = (handle->ofw == 7) && (handle->desc.R == 3) && (handle->desc.C == 512);
  const int ofw7_r1_c512_k2048 = (handle->ofw == 7) && (handle->desc.R == 1) &&
                                 (handle->desc.C == 512) && (handle->desc.K == 2048);
  const int ofw7_r1_c2048_k512 = (handle->ofw == 7) && (handle->desc.R == 1) &&
                                 (handle->desc.C == 2048) && (handle->desc.K == 512);
  int shuffle = (large_filters || ofw7_r3_c512 || ofw7_r1_c512_k2048 || ofw7_r1_c2048_k512) ? 1 : 0;

  /* never shuffle on SPR with low-precision input */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) )
  {
    shuffle = 0;
  }

  return shuffle;
}
/* Decides whether loading the accumulator can be skipped (1) or not (0).
 * This requires the OVERWRITE option and all IFM blocks being processed in one
 * go; filters other than 1x1 additionally require avoid_fmas_in_rim to be 0. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load( libxsmm_dnn_layer* handle ) {
  const int overwrite = ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0);
  const int all_ifm_at_once = (handle->blocksifm_blocking == handle->blocksifm);
  int avoid_load = 0;

  if ( overwrite && all_ifm_at_once ) {
    if ( (handle->desc.R == 1) && (handle->desc.S == 1) ) {
      avoid_load = 1;
    } else if ( handle->avoid_fmas_in_rim == 0 ) {
      avoid_load = 1;
    }
  }

  return avoid_load;
}
/* Returns the GEMM flags used for the forward kernels.
 * With LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS defined, a set of f32 shapes
 * requests the C-aligned non-temporal-store hint; on SPR with bf16/int8 input
 * the result is always the pair of "no tile-config setup/reset" flags. */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_fwd_gemm_flags( libxsmm_dnn_layer* handle ) {
  int flags = 0;

#if defined(LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS)
  /* If large image and NOT already loaded in accumulators, then use streaming stores */
  if ( ((handle->ofw >= 56) && (handle->desc.K >= 256) && (handle->avoid_acc_load == 1) &&
        (handle->desc.R == 1) && (handle->desc.S == 1)) ||
       /* tuned shapes with an output width of 56 and R == 1 */
       (handle->ofw == 56 && handle->desc.C == 64 && handle->desc.K == 64 && handle->desc.R == 1) ||
       (handle->ofw == 56 && handle->desc.C == 256 && handle->desc.K == 64 && handle->desc.R == 1) )
  {
    flags = LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT;
  }
  /* Disable since the GEMM output is going to f32 scratch */
  if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 || handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8 ) {
    flags = 0;
  }
#else
  LIBXSMM_UNUSED(handle);
#endif

  /* SPR with low-precision input: the kernels must neither set up nor reset the tile configuration */
  if ( (handle->target_archid == LIBXSMM_X86_AVX512_SPR) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) &&
       ((handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8)) )
  {
    flags = LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG;
  }

  return flags;
}
/**
 * Decide whether the forward pass has to copy the input into a padded buffer.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 if both the height and the width padding differ from the
 *         respective input padding (pad_h_in / pad_w_in), 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fwd_padding_copy( libxsmm_dnn_layer* handle ) {
  const int pad_h_differs = (handle->desc.pad_h != handle->desc.pad_h_in);
  const int pad_w_differs = (handle->desc.pad_w != handle->desc.pad_w_in);
  /* NOTE(review): a mismatch in only one dimension does not trigger the copy; confirm "&&" (not "||") is intended */
  return ((0 != pad_h_differs) && (0 != pad_w_differs)) ? 1 : 0;
}
/**
 * Compute sizes and offsets of the forward-pass scratch regions and store
 * them in the handle.
 *
 * The forward scratch consists of three consecutive regions:
 *   1. packing/padding buffer for the input,
 *   2. full low-precision output buffer (F32),
 *   3. per-thread block low-precision output buffer (F32).
 * Each size is rounded up to a multiple of LIBXSMM_CACHELINE.
 *
 * All products are evaluated in size_t (first factor is cast) so that large
 * shapes cannot overflow int arithmetic before widening.
 *
 * @param handle convolution layer handle; fwd_*_scratch_size,
 *               fwd_*_scratch_offset and fwd_scratch_size are written.
 */
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_fwd_scratch( libxsmm_dnn_layer* handle ) {
  handle->fwd_packing_padding_scratch_size = 0;
  /* packing of input */
  if (handle->pack_input != 0) {
    handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      handle->desc.H / handle->desc.u *
      handle->desc.W / handle->desc.v *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* logical padding with copying on the fly (takes precedence over the packing size above) */
  if (handle->fwd_padding_copy != 0) {
    handle->fwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      (handle->desc.H + 2 * handle->desc.pad_h) *
      (handle->desc.W + 2 * handle->desc.pad_w) *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* output buffer in high precision when we use BF16 or I8 */
  if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) ||
       (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
    /* cast the leading factor so the whole product is carried out in size_t */
    handle->fwd_lp_output_full_scratch_size = (size_t) LIBXSMM_MAX(
      (size_t)handle->desc.threads * handle->fwd_gemm_pixels * handle->ofmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32),
      (size_t)handle->desc.N * handle->desc.K * handle->ofwp * handle->ofhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32));
    handle->fwd_lp_output_block_scratch_size = (size_t)handle->desc.threads * handle->fwd_ofw_rb *
      handle->fwd_ofh_rb * handle->ofmblock *
      libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
  } else {
    handle->fwd_lp_output_full_scratch_size = 0;
    handle->fwd_lp_output_block_scratch_size = 0;
  }
  /* align sizes to full cacheline */
  handle->fwd_packing_padding_scratch_size += ( handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
    LIBXSMM_CACHELINE - (handle->fwd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
  handle->fwd_lp_output_full_scratch_size += ( handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
    LIBXSMM_CACHELINE - (handle->fwd_lp_output_full_scratch_size % LIBXSMM_CACHELINE);
  handle->fwd_lp_output_block_scratch_size += ( handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
    LIBXSMM_CACHELINE - (handle->fwd_lp_output_block_scratch_size % LIBXSMM_CACHELINE);

  /* set offsets: the three regions are laid out back to back */
  handle->fwd_packing_padding_scratch_offset = 0;
  handle->fwd_lp_output_full_scratch_offset = handle->fwd_packing_padding_scratch_size;
  handle->fwd_lp_output_block_scratch_offset = handle->fwd_lp_output_full_scratch_offset +
    handle->fwd_lp_output_full_scratch_size;

  /* set overall scratch size for forward */
  handle->fwd_scratch_size = handle->fwd_packing_padding_scratch_size +
    handle->fwd_lp_output_full_scratch_size +
    handle->fwd_lp_output_block_scratch_size;
}
/**********************************************************/
/* Helper functions for BWD convolutions' parameter setup */
/**********************************************************/
/**
 * Decide whether the backward pass has to use the fallback loops.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 if any of the fallback criteria below holds, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_fallback_loops_bwd( libxsmm_dnn_layer* handle ) {
  const int R = handle->desc.R;
  const int S = handle->desc.S;
  /* FIXME: Fallback if MB is not divisible by number of threads */
  const int uneven_minibatch = (0 != (handle->desc.N % handle->desc.threads));
  /* 1x1 filter combined with non-zero padding */
  const int padded_1x1 = (1 == R) && (1 == S) && ((0 != handle->desc.pad_h) || (0 != handle->desc.pad_w));
  /* filter extent > 1 in a dimension that has no padding */
  const int unpadded_filter = ((1 < R) && (0 == handle->desc.pad_h)) || ((1 < S) && (0 == handle->desc.pad_w));
  /* filter extent > 1 in a dimension that lacks input or output padding */
  const int unpadded_tensors = ((1 < R) && ((0 == handle->desc.pad_h_out) || (0 == handle->desc.pad_h_in)))
                            || ((1 < S) && ((0 == handle->desc.pad_w_out) || (0 == handle->desc.pad_w_in)));
  /* filter extent > 1 in a dimension that is strided */
  const int strided_filter = ((1 < R) && (1 < handle->desc.u)) || ((1 < S) && (1 < handle->desc.v));

  return ((0 != uneven_minibatch) || (0 != padded_1x1) || (0 != unpadded_filter)
       || (0 != unpadded_tensors) || (0 != strided_filter)) ? 1 : 0;
}
/**
 * Register-blocking factor along the output width for the backward pass.
 *
 * @param handle convolution layer handle.
 * @return the value chosen by the forward setup routine (same heuristic).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofw_rb( libxsmm_dnn_layer* handle ) {
  /* backward simply reuses the forward decision */
  return libxsmm_dnn_convolution_setup_fwd_ofw_rb(handle);
}
/**
 * Register-blocking factor along the output height for the backward pass.
 *
 * @param handle convolution layer handle.
 * @return the value chosen by the forward setup routine (same heuristic).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_ofh_rb( libxsmm_dnn_layer* handle ) {
  /* backward simply reuses the forward decision */
  return libxsmm_dnn_convolution_setup_fwd_ofh_rb(handle);
}
/**
 * Number of pixels processed by one backward GEMM.
 *
 * @param handle convolution layer handle (reads bwd_ofw_rb and bwd_ofh_rb,
 *               which must already be set).
 * @return bwd_ofw_rb * bwd_ofh_rb, or an enlarged pixel count for the
 *         SPR BF16/I8 case described below.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_pixels_gemm( libxsmm_dnn_layer* handle ) {
  const int low_precision = (LIBXSMM_DNN_DATATYPE_BF16 == handle->datatype_in)
                         || (LIBXSMM_DNN_DATATYPE_I8 == handle->datatype_in);
  const int spr_target = (LIBXSMM_X86_AVX512_SPR == handle->target_archid)
                      && (handle->target_archid <= LIBXSMM_X86_ALLFEAT);
  int pixels = handle->bwd_ofw_rb * handle->bwd_ofh_rb;

  /* In the case below we calculate redundantly pixels in order to efficiently use AMX */
  /* NOTE(review): the original condition was "R != 1 || R != 1" (same operand twice);
   * the second operand was presumably meant to be S. Behavior is kept as-is here. */
  if ((0 != spr_target) && (0 != low_precision) && (1 != handle->desc.R) && (handle->ofw < 24)) {
    pixels = (handle->bwd_ofw_rb + 2 * handle->desc.pad_w) * (handle->bwd_ofh_rb - 2)
           + 2 * (handle->bwd_ofw_rb + handle->desc.pad_w);
  }
  return pixels;
}
/**
 * Blocking factor along H for the backward pass.
 *
 * @param handle convolution layer handle.
 * @return the value chosen by the forward setup routine (same heuristic).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_bwd_block_H( libxsmm_dnn_layer* handle ) {
  /* backward simply reuses the forward decision */
  return libxsmm_dnn_convolution_setup_fwd_block_H(handle);
}
/**
 * Loop order selector for the backward pass.
 *
 * @param handle convolution layer handle.
 * @return the value chosen by the forward setup routine (same heuristic).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_bwd( libxsmm_dnn_layer* handle ) {
  /* backward simply reuses the forward decision */
  return libxsmm_dnn_convolution_setup_loop_order_fwd(handle);
}
/**
 * IFM blocking factor for the backward pass.
 *
 * @param handle convolution layer handle (read only).
 * @return the number of IFM blocks, capped at 16.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_IFM( libxsmm_dnn_layer* handle ) {
  return LIBXSMM_MIN(handle->blocksifm, 16);
}
/**
 * OFM blocking factor for the backward pass.
 *
 * @param handle convolution layer handle (reads blocksofm_blocking, which
 *               must already be set and non-zero).
 * @return the first value, counting upwards from 8, that is divisible by
 *         blocksofm_blocking.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_bwd_OFM( libxsmm_dnn_layer* handle ) {
  int candidate;
  /* linear search upwards starting at 8 */
  for (candidate = 8; 0 != (candidate % handle->blocksofm_blocking); ++candidate) {
    /* nothing to do: the loop header performs the search */
  }
  return candidate;
}
/**
 * Decide whether the backward pass packs its input.
 *
 * @param handle convolution layer handle (reads bwd_ofh_rb, which must
 *               already be set).
 * @return 1 if the stride u differs from 1 and bwd_ofh_rb differs from 1,
 *         0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_bwd( libxsmm_dnn_layer* handle ) {
  const int strided_h = (1 != handle->desc.u);
  const int multi_row_block = (1 != handle->bwd_ofh_rb);
  return ((0 != strided_h) && (0 != multi_row_block)) ? 1 : 0;
}
/**
 * Decide whether to parallelize over input feature maps.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 if the output width is at most 7, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_ifm_parallelization( libxsmm_dnn_layer* handle ) {
  return (handle->ofw <= 7) ? 1 : 0;
}
/**
 * Decide whether the backward pass avoids FMAs in the rim.
 *
 * @param handle convolution layer handle.
 * @return the value chosen by the forward setup routine (same heuristic).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_bwd( libxsmm_dnn_layer* handle ) {
  /* backward simply reuses the forward decision */
  return libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd(handle);
}
/**
 * Choose how many OFM blocks are processed together in the backward pass.
 *
 * @param handle convolution layer handle (read only).
 * @return the OFM blocking; always a divisor of handle->blocksofm
 *         (falls back to 1 if the heuristic value does not divide it).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_blocksofm_blocking( libxsmm_dnn_layer* handle ) {
  const int low_precision = (LIBXSMM_DNN_DATATYPE_BF16 == handle->datatype_in)
                         || (LIBXSMM_DNN_DATATYPE_I8 == handle->datatype_in);
  const int spr_target = (LIBXSMM_X86_AVX512_SPR == handle->target_archid)
                      && (handle->target_archid <= LIBXSMM_X86_ALLFEAT);
  const int is_1x1 = (1 == handle->desc.R) && (1 == handle->desc.S);
  int blocking;

  if (((0 != spr_target) && (0 != low_precision)) || (0 != is_1x1)) {
    /* SPR BF16/I8 path and 1x1 filters: block over all OFM blocks */
    blocking = handle->blocksofm;
  }
  else if ((3 == handle->desc.R) && (3 == handle->desc.S) && (7 == handle->ofh) && (7 == handle->ofw)) {
    /* hand-picked value for 3x3 filters on 7x7 outputs */
    blocking = 2;
  }
  else {
    blocking = 1;
  }

  /* the blocking must evenly divide the number of OFM blocks */
  if (0 != (handle->blocksofm % blocking)) {
    blocking = 1;
  }
  return blocking;
}
/**
 * Compute the initial GEMM flags for the backward kernels.
 *
 * @param handle convolution layer handle (read only).
 * @return the two tile-config flags (NO_RESET | NO_SETUP) on the SPR target
 *         with BF16 or I8 input, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_bwd_gemm_flags( libxsmm_dnn_layer* handle ) {
  const int low_precision = (LIBXSMM_DNN_DATATYPE_BF16 == handle->datatype_in)
                         || (LIBXSMM_DNN_DATATYPE_I8 == handle->datatype_in);
  const int spr_target = (LIBXSMM_X86_AVX512_SPR == handle->target_archid)
                      && (handle->target_archid <= LIBXSMM_X86_ALLFEAT);

  return ((0 != spr_target) && (0 != low_precision))
    ? (LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG)
    : 0;
}
/**
 * Decide whether the backward pass spreads its input.
 *
 * @param handle convolution layer handle (reads bwd_ofh_rb, which must
 *               already be set).
 * @return 1 if the convolution is strided in at least one dimension and
 *         bwd_ofh_rb equals 1, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_spread_input_bwd( libxsmm_dnn_layer* handle ) {
  const int strided = (1 != handle->desc.u) || (1 != handle->desc.v);
  return ((0 != strided) && (1 == handle->bwd_ofh_rb)) ? 1 : 0;
}
/**
 * Decide whether the backward pass may skip loading the accumulator.
 *
 * Requires the OVERWRITE option and that the OFM blocking spans all OFM
 * blocks; for filters other than 1x1 it additionally requires that rim
 * FMAs are not being avoided.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 if the accumulator load can be avoided, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_acc_load_bwd( libxsmm_dnn_layer* handle ) {
  /* Without the overwrite option the previous contents are needed */
  if ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) <= 0) {
    return 0;
  }
  /* Both filter-shape cases require the blocking to cover every OFM block */
  if (handle->blocksofm_blocking != handle->blocksofm) {
    return 0;
  }
  if ((1 == handle->desc.R) && (1 == handle->desc.S)) {
    return 1;
  }
  return (0 == handle->avoid_fmas_in_rim) ? 1 : 0;
}
/**
 * Compute sizes and offsets of the backward-pass scratch regions and store
 * them in the handle.
 *
 * The backward scratch consists of three consecutive regions:
 *   1. transposed filter,
 *   2. packing/padding buffer,
 *   3. full low-precision input buffer (F32; only for BF16 input).
 * Each size is rounded up to a multiple of LIBXSMM_CACHELINE.
 *
 * All products are evaluated in size_t (first factor is cast) so that large
 * shapes cannot overflow int arithmetic before widening.
 *
 * @param handle convolution layer handle; bwd_*_scratch_size,
 *               bwd_*_scratch_offset and bwd_scratch_size are written.
 */
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bwd_scratch( libxsmm_dnn_layer* handle ) {
  /* transpose of weights */
  handle->bwd_filter_trans_scratch_size = (size_t)handle->desc.C * handle->desc.K *
    handle->desc.R * handle->desc.S *
    libxsmm_dnn_typesize(handle->datatype_in);

  handle->bwd_packing_padding_scratch_size = 0;
  /* packing of input */
  if (handle->pack_input_bwd != 0) {
    handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
      handle->ofhp * handle->ofwp *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* logical padding with copying on the fly (takes precedence over the packing size above) */
  if (handle->use_fallback_bwd_loops != 0) {
    handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock *
      (handle->desc.H + 2 * handle->desc.pad_h) *
      (handle->desc.W + 2 * handle->desc.pad_w) *
      libxsmm_dnn_typesize(handle->datatype_in);
  }
  /* input buffer in high precision when we use BF16 */
  if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
    /* cast the leading factor so the whole product is carried out in size_t */
    handle->bwd_lp_input_full_scratch_size = (size_t) LIBXSMM_MAX(
      (size_t)handle->desc.threads * handle->bwd_gemm_pixels * handle->ifmblock * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32),
      (size_t)handle->desc.N * handle->desc.C * handle->ifwp * handle->ifhp * libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32));
    /* logical padding with copying on the fly: the padded copy is held in F32 for BF16 input */
    if (handle->use_fallback_bwd_loops != 0) {
      handle->bwd_packing_padding_scratch_size = (size_t)handle->desc.threads * handle->ifmblock *
        (handle->desc.H + 2 * handle->desc.pad_h) *
        (handle->desc.W + 2 * handle->desc.pad_w) *
        libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
    }
  } else {
    handle->bwd_lp_input_full_scratch_size = 0;
  }
  /* align sizes to full cacheline */
  handle->bwd_filter_trans_scratch_size += ( handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
    LIBXSMM_CACHELINE - (handle->bwd_filter_trans_scratch_size % LIBXSMM_CACHELINE);
  handle->bwd_packing_padding_scratch_size += ( handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
    LIBXSMM_CACHELINE - (handle->bwd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
  handle->bwd_lp_input_full_scratch_size += ( handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
    LIBXSMM_CACHELINE - (handle->bwd_lp_input_full_scratch_size % LIBXSMM_CACHELINE);

  /* set offsets: the three regions are laid out back to back */
  handle->bwd_filter_trans_scratch_offset = 0;
  handle->bwd_packing_padding_scratch_offset = handle->bwd_filter_trans_scratch_size;
  handle->bwd_lp_input_full_scratch_offset = handle->bwd_packing_padding_scratch_offset +
    handle->bwd_packing_padding_scratch_size;

  /* set overall scratch size for backward */
  handle->bwd_scratch_size = handle->bwd_filter_trans_scratch_size +
    handle->bwd_packing_padding_scratch_size +
    handle->bwd_lp_input_full_scratch_size;
}
/**********************************************************/
/* Helper functions for UPD convolutions' parameter setup */
/**********************************************************/
/**
 * Loop order selector for the weight-update pass.
 *
 * @param handle convolution layer handle (read only).
 * @return 0 for the shapes listed below, 1 for everything else.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_loop_order_upd( libxsmm_dnn_layer* handle ) {
  const int R = handle->desc.R;
  const int C = handle->desc.C;
  const int K = handle->desc.K;

  /* any output of width 7 */
  if (7 == handle->ofw) {
    return 0;
  }
  /* width 14, except for the 1x? filter with C == 1024 and K == 256 */
  if ((14 == handle->ofw) && !((1 == R) && (1024 == C) && (256 == K))) {
    return 0;
  }
  /* width 28 with R == 1, C == 256, K == 512 */
  if ((28 == handle->ofw) && (1 == R) && (256 == C) && (512 == K)) {
    return 0;
  }
  /* height 28, stride u == 1, C == 128 and either (R == 1, K == 512) or (R == 3, K == 128) */
  if ((28 == handle->ofh) && (1 == handle->desc.u) && (128 == C)
   && (((1 == R) && (512 == K)) || ((3 == R) && (128 == K))))
  {
    return 0;
  }
  return 1;
}
/**
 * Decide whether the weight-update pass packs its input.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 if the input should be packed, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_pack_input_upd( libxsmm_dnn_layer* handle ) {
  /* Pack input only for very small images, 1x1 convs, with large K to amortize the relevant overhead */
  const int small_image = (handle->ofh <= 7);
  const int is_1x1 = (1 == handle->desc.R) && (1 == handle->desc.S);
  const int strided_both = (1 != handle->desc.u) && (1 != handle->desc.v);
  const int large_k = (2048 <= handle->desc.K);

  return ((0 != small_image) && (0 != is_1x1) && (0 != strided_both) && (0 != large_k)) ? 1 : 0;
}
/**
 * Decide whether the weight-update pass avoids FMAs in the rim.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 if rim FMAs are to be avoided, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd( libxsmm_dnn_layer* handle ) {
  /* never avoid rim FMAs unless the minibatch equals the number of threads */
  if (handle->desc.N != handle->desc.threads) {
    return 0;
  }
  /* Avoid rim FMAs only for small images (3x3 filter with padding 1 in both dimensions) */
  if ((handle->ofh <= 7)
   && (3 == handle->desc.R) && (3 == handle->desc.S)
   && (1 == handle->desc.pad_w) && (1 == handle->desc.pad_h))
  {
    return 1;
  }
  return 0;
}
/**
 * Register-blocking factor along the output width for the weight-update pass.
 *
 * @param handle convolution layer handle (read only).
 * @return the full output width (handle->ofw).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofw_rb( libxsmm_dnn_layer* handle ) {
  return handle->ofw;
}
/**
 * Register-blocking factor along the output height for the weight-update pass.
 *
 * The rules below are applied in order; a later rule overrides an earlier one.
 *
 * @param handle convolution layer handle (reads upd_use_batchreduce,
 *               upd_pack_input, upd_linearized_tasklist and
 *               upd_avoid_rim_fmas, which must already be set).
 * @return the chosen blocking factor.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_ofh_rb( libxsmm_dnn_layer* handle ) {
  const int u = handle->desc.u;
  const int v = handle->desc.v;
  const int linearized = handle->upd_linearized_tasklist;
  const int batchreduce = handle->upd_use_batchreduce;
  const int not_1x1 = (1 != handle->desc.R) || (1 != handle->desc.S);
  /* Restrict the reduction chain which is ofw_rb*ofh_rb */
  int rows = (handle->ofh <= 28) ? handle->ofh : 1;

  /* In the following scenario with strided convolutions and non batch reduce kernel make sure we have ofh_rb = 1 */
  if ((1 != u) && (1 != v) && (0 == batchreduce) && (0 == handle->upd_pack_input)) {
    rows = 1;
  }
  /* If using linearized taskview and have strided convs, make sure ofh_rb is 1.. */
  if ((1 == linearized) && (0 == handle->upd_avoid_rim_fmas) && (0 == handle->upd_pack_input) && (1 != u)) {
    rows = 1;
  }
  /* Without batch-reduce, filters other than 1x1 use a single row (tasklist flag being 0 or 1) */
  if (((1 == linearized) || (0 == linearized)) && (0 == batchreduce) && (0 != not_1x1)) {
    rows = 1;
  }
  /* hand-picked value for ofw == 56 with R == 1 */
  if ((56 == handle->ofw) && (1 == handle->desc.R)) {
    rows = 2;
  }
  /* linearized tasklist + batch-reduce + rim avoidance: use the full height */
  if ((1 == linearized) && (1 == batchreduce) && (1 == handle->upd_avoid_rim_fmas)) {
    rows = handle->ofh;
  }
  /* minibatch != threads with a strided, larger-than-1x1 filter: single row */
  if ((handle->desc.N != handle->desc.threads)
   && ((1 < handle->desc.R) || (1 < handle->desc.S))
   && ((1 < u) || (1 < v)))
  {
    rows = 1;
  }
  return rows;
}
/**
 * IFM blocking factor for the weight-update pass.
 *
 * @param handle convolution layer handle (read only).
 * @return 4 for unit-stride 1x1 filters with ofh == 56, 1 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_IFM( libxsmm_dnn_layer* handle ) {
  const int is_1x1 = (1 == handle->desc.R) && (1 == handle->desc.S);
  const int unit_stride = (1 == handle->desc.u) && (1 == handle->desc.v);
  return ((56 == handle->ofh) && (0 != is_1x1) && (0 != unit_stride)) ? 4 : 1;
}
/**
 * OFM blocking factor for the weight-update pass.
 *
 * @param handle convolution layer handle (currently not consulted).
 * @return always 1.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_block_upd_OFM( libxsmm_dnn_layer* handle ) {
  LIBXSMM_UNUSED(handle);
  return 1;
}
/**
 * Image blocking factor for the batch-reduce kernel of the weight-update pass.
 *
 * @param handle convolution layer handle (currently not consulted).
 * @return always 1.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_img_batchreduce_block( libxsmm_dnn_layer* handle ) {
  LIBXSMM_UNUSED(handle);
  return 1;
}
/**
 * Decide whether the weight-update pass uses the batch-reduce GEMM kernel.
 *
 * @param handle convolution layer handle (reads upd_pack_input,
 *               upd_linearized_tasklist and upd_avoid_rim_fmas, which must
 *               already be set).
 * @return 1 to use batch-reduce, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_use_batchreduce_upd( libxsmm_dnn_layer* handle ) {
  /* With a linearized tasklist the rim-FMA setting decides and overrides the rules below */
  if (1 == handle->upd_linearized_tasklist) {
    if (0 == handle->upd_avoid_rim_fmas) {
      return 0;
    }
    if (1 == handle->upd_avoid_rim_fmas) {
      return 1;
    }
  }
  /* If W is large, no need for batchreduce kernel */
  if (56 <= handle->ofw) {
    return 0;
  }
  /* If we have packed the input, then disable batch-reduce GEMM */
  if (1 == handle->upd_pack_input) {
    return 0;
  }
  return 1;
}
/**
 * Choose the number of weight copies used by the weight-update pass.
 *
 * The rules below are applied in order; a later rule overrides an earlier one.
 *
 * @param handle convolution layer handle (reads upd_linearized_tasklist and
 *               upd_use_batchreduce, which must already be set).
 * @return the number of weight copies.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_weight_copies_upd( libxsmm_dnn_layer* handle ) {
  const int nthreads = handle->desc.threads;
  const int minibatch = handle->desc.N;
  const int ofw = handle->ofw;
  const int u = handle->desc.u;
  /* start from one copy per thread, or 9 for narrow outputs */
  int copies = (ofw <= 14) ? 9 : nthreads;

  /* hand-picked value for N == threads == 92 */
  if ((92 == minibatch) && (92 == nthreads)) {
    const int unit_stride_3x3 = (3 == handle->desc.R) && (3 == handle->desc.S)
                             && (1 == u) && (1 == handle->desc.v);
    if ((14 == ofw) || ((7 == ofw) && (0 != unit_stride_3x3))) {
      copies = 23;
    }
  }

  /* decrease until the number of copies evenly divides the thread count */
  for (; 0 != (nthreads % copies); --copies) {
    /* nothing to do: the loop header performs the search */
  }

  /* FIXME: Hardcoded logic for N=27, N=26 */
  if ((27 == minibatch) && (27 == nthreads) && (1 == handle->desc.R) && (14 == ofw) && (1 == u)) {
    copies = 7;
  }
  if (((14 == handle->ofh) || ((7 == ofw) && (2 == u))) && (26 == minibatch) && (26 == nthreads)) {
    copies = 13;
  }

  /* minibatch != threads: one copy per image, unless neither linearized tasklist nor batch-reduce is used */
  if ((minibatch != nthreads)
   && !((0 == handle->upd_linearized_tasklist) && (0 == handle->upd_use_batchreduce)))
  {
    copies = minibatch;
  }

  /* Make sure a single copy when we use linearized-task view */
  if (1 == handle->upd_linearized_tasklist) {
    copies = 1;
  }
  return copies;
}
/**
 * Decide whether the weight-update pass uses a linearized task list.
 *
 * The rules are checked from highest to lowest priority; the first match wins.
 *
 * @param handle convolution layer handle (read only).
 * @return 1 to use the linearized task list, 0 otherwise.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_linearized_tasklist_upd( libxsmm_dnn_layer* handle ) {
  const int N = handle->desc.N;
  const int nthreads = handle->desc.threads;

  /* stride 2 in both dimensions with K == 512: never linearize */
  if ((2 == handle->desc.u) && (2 == handle->desc.v) && (512 == handle->desc.K)) {
    return 0;
  }
#if 0
  if ((handle->blocksofm * handle->blocksifm * handle->desc.R * handle->desc.S > (handle->desc.threads * 4)) && (handle->ofh <= 56)) {
    return 1;
  }
#endif
  /* hand-picked: 14x14 output with N == threads == 23 */
  if ((14 == handle->ofh) && (14 == handle->ofw) && (23 == N) && (23 == nthreads)) {
    return 1;
  }
  /* hand-picked exception: ofw == 7, N == threads == 92, unit-stride 3x3 filter */
  if ((7 == handle->ofw) && (92 == N) && (92 == nthreads)
   && (3 == handle->desc.R) && (3 == handle->desc.S)
   && (1 == handle->desc.u) && (1 == handle->desc.v))
  {
    return 0;
  }
  /* Use linearized task-list (i.e. no reduction) only if small images and large filters */
  return ((handle->ofh <= 10) && (handle->ofw <= 10)) ? 1 : 0;
}
/**
 * Compute the initial GEMM flags for the weight-update kernels.
 *
 * @param handle convolution layer handle (currently not consulted).
 * @return always 0 (no flags).
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_init_upd_gemm_flags( libxsmm_dnn_layer* handle ) {
  LIBXSMM_UNUSED(handle);
  return 0;
}
/**
 * Configure the BF16 weight-update pass: chooses between the "linearized
 * pixels" and the non-linearized formulation and fills in the derived pixel
 * counts, packing switches and the number of weight copies in the handle.
 *
 * Pixel counts are padded up to a multiple of 2.
 *
 * @param handle convolution layer handle; the upd_* switches, pixel counts,
 *               weight_copies and use_intermediate_f32_wt_tensor are written.
 */
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_bf16_upd( libxsmm_dnn_layer* handle ) {
  const int pixel_multiple = 2;
  /* with upd_padding_copy the padded extents include the logical padding on both sides */
  const int padded_copy = (1 == handle->upd_padding_copy);
  const int in_h  = (0 != padded_copy) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
  const int in_w  = (0 != padded_copy) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
  const int out_h = (0 != padded_copy) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
  const int out_w = (0 != padded_copy) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;

  /* default: linearized pixels, unless one of the two cases below applies */
  handle->upd_linearized_pixels = 1;
  if ((1 != handle->desc.S) && (1 != handle->desc.v)) {
    handle->upd_linearized_pixels = 0;
    handle->upd_trans_w_only = 0;
  }
  /* For large images facilitate the "large" transposes by blocking the pixel/reduction domains */
  if ((56 <= handle->ofw) && (56 <= handle->ofh)
   && (1 == handle->desc.R) && (1 == handle->desc.S)
   && (1 == handle->desc.u) && (1 == handle->desc.v))
  {
    handle->upd_linearized_pixels = 0;
    handle->upd_trans_w_only = 1;
  }

  handle->on_the_fly_input_packing = 0;
  handle->upd_pack_input_upfront = 0;
  handle->use_hybrid_imgofm_parallelization = 0;
  handle->upd_linearized_tasklist = 0;

  if (1 == handle->upd_linearized_pixels) {
    /* Logistics to pad accumulation chainlength */
    const int n_compute = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh - 1);
    const int n_pad = (0 == (n_compute % pixel_multiple)) ? 0
                    : (n_compute / pixel_multiple + 1) * pixel_multiple - n_compute;
    const int n_accum = n_compute + n_pad;
    /* Logistics for input transpose and additional pixel padding */
    const int init_offset_max = 2 * handle->desc.pad_h * in_w + 2 * handle->desc.pad_w;
    const int compute_offset_max = init_offset_max + n_accum;
    const int input_pad = (compute_offset_max > in_w * in_h) ? compute_offset_max - in_w * in_h : 0;

    /* In this case compact input upfront */
    if ((1 == handle->desc.R) && (1 == handle->desc.S)
     && ((1 != handle->desc.u) || (1 != handle->desc.v)))
    {
      handle->upd_pack_input_upfront = 1;
    }

    handle->input_pixels = in_w * in_h + input_pad;
    if (handle->upd_pack_input_upfront) {
      /* the compacted input only holds the accumulation chain */
      handle->input_pixels = n_accum;
    }
    handle->output_pixels = n_accum;
    handle->pixel_blocking = n_accum;
    handle->n_used_pixels = n_accum;
    handle->compute_pixels = n_compute;
    handle->use_intermediate_f32_wt_tensor = (handle->pixel_blocking == handle->n_used_pixels) ? 0 : 1;

    if (handle->ofw <= 14) {
      handle->use_hybrid_imgofm_parallelization = 1;
      handle->weight_copies = libxsmm_dnn_convolution_setup_weight_copies_upd(handle);
      /* exception: width 14 with large K falls back to one copy per thread */
      if ((14 == handle->ofw) && (1024 <= handle->desc.K)) {
        handle->use_hybrid_imgofm_parallelization = 0;
        handle->weight_copies = handle->desc.threads;
      }
    }
    else {
      handle->weight_copies = handle->desc.threads;
    }
  }
  else if (0 == handle->upd_linearized_pixels) {
    /* pad the output width up to the next multiple of pixel_multiple */
    const int w_pad = (0 == (handle->ofw % pixel_multiple)) ? 0
                    : (handle->ofw / pixel_multiple + 1) * pixel_multiple - handle->ofw;

    handle->weight_copies = handle->desc.threads;
    if (1 != handle->desc.v) {
      handle->on_the_fly_input_packing = 1;
    }
    handle->ofwp_extended = out_w + w_pad;
    handle->ifwp_extended = in_w + w_pad;
    handle->output_pixels = out_h * handle->ofwp_extended;
    /* coverity[identical_branches] */
    handle->batchreduce_h_pixels = (handle->upd_trans_w_only) ? 1 : 1; /* TODO: identical_branches */
    handle->use_intermediate_f32_wt_tensor = (handle->batchreduce_h_pixels == handle->ofh) ? 0 : 1;
  }

  /* minibatch != threads overrides the decisions above */
  if (handle->desc.N != handle->desc.threads) {
    handle->use_intermediate_f32_wt_tensor = 1;
    handle->use_hybrid_imgofm_parallelization = 0;
    handle->weight_copies = LIBXSMM_MIN(handle->desc.N, handle->desc.threads);
  }
}
LIBXSMM_API_INLINE
void
libxsmm_dnn_convolution_setup_bf16_upd_amx
(
libxsmm_dnn_layer
*
handle
)
{
/* JIT related variables... */
libxsmm_blasint
LDA
=
handle
->
ofmblock
;
libxsmm_blasint
LDB
=
handle
->
input_pixels
;
libxsmm_blasint
LDC
=
handle
->
ofmblock
;
int
prefetch_mode
=
libxsmm_get_gemm_prefetch
(
LIBXSMM_GEMM_PREFETCH_NONE
);
int
l_flags
=
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
)
|
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG
|
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG
;
int
l_tc_flags
=
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG
|
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
);
size_t
stride_a
,
stride_b
;
int
unroll_hint
;
float
beta
;
int
remainder_pixels
,
max_init_offset
,
max_compute_offset_input
,
input_compute_pad
,
accum_length_pixels
,
compute_pixels
;
const
int
multiple_target
=
32
;
int
IFHP
=
(
handle
->
upd_padding_copy
==
1
)
?
handle
->
ifhp
+
2
*
handle
->
desc
.
pad_h
:
handle
->
ifhp
;
int
IFWP
=
(
handle
->
upd_padding_copy
==
1
)
?
handle
->
ifwp
+
2
*
handle
->
desc
.
pad_w
:
handle
->
ifwp
;
int
OFWP
=
(
handle
->
upd_padding_copy
==
1
)
?
handle
->
ofwp
+
2
*
handle
->
desc
.
pad_w
:
handle
->
ofwp
;
handle
->
upd_linearized_pixels
=
1
;
if
(
handle
->
desc
.
S
!=
1
&&
handle
->
desc
.
v
!=
1
)
{
handle
->
upd_linearized_pixels
=
0
;
}
handle
->
fuse_upd_transposes
=
1
;
handle
->
pack_to_cnhw
=
0
;
handle
->
on_the_fly_input_packing
=
0
;
handle
->
upd_pack_input_upfront
=
0
;
handle
->
use_hybrid_imgofm_parallelization
=
0
;
handle
->
upd_linearized_tasklist
=
0
;
if
(((
handle
->
target_archid
==
LIBXSMM_X86_AVX512_SPR
)
&&
(
handle
->
target_archid
<=
LIBXSMM_X86_ALLFEAT
))
&&
(
handle
->
ofw
==
7
)
&&
(
handle
->
desc
.
R
==
1
)
&&
(
handle
->
desc
.
S
==
1
)
)
{
handle
->
pack_to_cnhw
=
1
;
}
if
(
handle
->
upd_linearized_pixels
==
1
)
{
if
(
handle
->
pack_to_cnhw
==
0
)
{
handle
->
fuse_upd_transposes
=
1
;
/* Logistics to pad accumulation chainlength */
compute_pixels
=
handle
->
ofw
*
handle
->
ofh
+
2
*
handle
->
desc
.
pad_w
*
(
handle
->
ofh
-
1
);
remainder_pixels
=
(
compute_pixels
%
multiple_target
==
0
)
?
0
:
(
compute_pixels
/
multiple_target
+
1
)
*
multiple_target
-
compute_pixels
;
accum_length_pixels
=
compute_pixels
+
remainder_pixels
;
/* In this case compact input upfront */
if
(
handle
->
desc
.
R
==
1
&&
handle
->
desc
.
S
==
1
&&
(
handle
->
desc
.
u
!=
1
||
handle
->
desc
.
v
!=
1
))
{
handle
->
upd_pack_input_upfront
=
1
;
}
/* Logistics for input transpose and additional pixel padding */
max_init_offset
=
2
*
handle
->
desc
.
pad_h
*
IFWP
+
2
*
handle
->
desc
.
pad_w
;
max_compute_offset_input
=
max_init_offset
+
accum_length_pixels
;
input_compute_pad
=
(
max_compute_offset_input
>
IFWP
*
IFHP
)
?
max_compute_offset_input
-
IFWP
*
IFHP
:
0
;
handle
->
input_pixels
=
IFWP
*
IFHP
+
input_compute_pad
;
if
(
handle
->
upd_pack_input_upfront
)
{
handle
->
input_pixels
=
accum_length_pixels
;
}
handle
->
output_pixels
=
accum_length_pixels
;
handle
->
pixel_blocking
=
accum_length_pixels
;
handle
->
n_used_pixels
=
accum_length_pixels
;
handle
->
compute_pixels
=
compute_pixels
;
handle
->
use_intermediate_f32_wt_tensor
=
(
handle
->
pixel_blocking
==
handle
->
n_used_pixels
)
?
0
:
1
;
#if 0
handle->scratch2_size = (size_t) (handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(float)/2);
if (handle->use_intermediate_f32_wt_tensor) {
handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float);
}
handle->scratch3_size = (size_t) (handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(float)/2);
#endif
if
(
handle
->
ofw
<=
14
)
{
handle
->
use_hybrid_imgofm_parallelization
=
1
;
handle
->
fuse_upd_transposes
=
0
;
}
else
{
handle
->
weight_copies
=
handle
->
desc
.
threads
;
}
if
((
handle
->
ofmblock
%
32
!=
0
)
||
(
handle
->
ifmblock
%
32
!=
0
))
{
handle
->
fuse_upd_transposes
=
0
;
}
}
else
{
/* Logistics to pad accumulation chainlength */
handle
->
use_hybrid_imgofm_parallelization
=
1
;
handle
->
weight_copies
=
7
;
while
(
handle
->
desc
.
threads
%
handle
->
weight_copies
!=
0
)
{
handle
->
weight_copies
--
;
}
compute_pixels
=
handle
->
ofw
*
handle
->
ofh
*
(
handle
->
desc
.
N
/
handle
->
weight_copies
);
remainder_pixels
=
(
compute_pixels
%
multiple_target
==
0
)
?
0
:
(
compute_pixels
/
multiple_target
+
1
)
*
multiple_target
-
compute_pixels
;
handle
->
remainder_pixels
=
remainder_pixels
;
accum_length_pixels
=
compute_pixels
+
remainder_pixels
;
handle
->
output_pixels
=
accum_length_pixels
*
handle
->
weight_copies
;
handle
->
input_pixels
=
accum_length_pixels
*
handle
->
weight_copies
;
handle
->
pixel_blocking
=
accum_length_pixels
;
handle
->
n_used_pixels
=
accum_length_pixels
;
handle
->
use_intermediate_f32_wt_tensor
=
0
;
#if 0
handle->scratch2_size = (size_t) (handle->weight_copies * handle->output_pixels * handle->desc.K * sizeof(float)/2);
handle->scratch3_size = (size_t) (handle->weight_copies * handle->input_pixels * handle->desc.C * sizeof(float)/2);
#endif
}
}
if
(
handle
->
upd_linearized_pixels
==
0
)
{
handle
->
weight_copies
=
handle
->
desc
.
threads
;
if
(
handle
->
desc
.
v
!=
1
)
{
handle
->
on_the_fly_input_packing
=
1
;
}
remainder_pixels
=
(
handle
->
ofw
%
multiple_target
==
0
)
?
0
:
(
handle
->
ofw
/
multiple_target
+
1
)
*
multiple_target
-
handle
->
ofw
;
handle
->
remainder_pixels
=
remainder_pixels
;
handle
->
ofwp_extended
=
OFWP
+
remainder_pixels
;
handle
->
ifwp_extended
=
IFWP
+
remainder_pixels
;
handle
->
batchreduce_h_pixels
=
handle
->
ofh
;
handle
->
use_intermediate_f32_wt_tensor
=
(
handle
->
batchreduce_h_pixels
==
handle
->
ofh
)
?
0
:
1
;
#if 0
handle->scratch2_size = (size_t) (handle->desc.N * handle->ofhp*handle->ofwp_extended * handle->desc.K * sizeof(float)/2);
if (handle->use_intermediate_f32_wt_tensor) {
handle->scratch2_size += (size_t) handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads * sizeof(float);
}
handle->scratch3_size = (size_t) (handle->desc.N * handle->ifhp * handle->ifwp_extended * handle->desc.C * sizeof(float)/2);
#endif
}
/* Now that all decisions have been made, JIT the proper kernel... */
beta
=
(
handle
->
use_intermediate_f32_wt_tensor
)
?
(
float
)
1
.
0
:
(
float
)
0
.
0
;
if
(
handle
->
upd_linearized_pixels
==
0
)
{
LDA
=
handle
->
ofmblock
;
LDB
=
IFHP
*
handle
->
ifwp_extended
;
LDC
=
handle
->
ofmblock
;
prefetch_mode
=
libxsmm_get_gemm_prefetch
(
LIBXSMM_GEMM_PREFETCH_NONE
);
unroll_hint
=
handle
->
batchreduce_h_pixels
;
stride_a
=
handle
->
ofwp_extended
*
handle
->
ofmblock
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
stride_b
=
handle
->
desc
.
u
*
handle
->
ifwp_extended
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
handle
->
upd_config_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
ofw
+
handle
->
remainder_pixels
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_tc_flags
,
NULL
);
handle
->
upd_compute_kernel_brgemm_no_linearized_pixels
=
libxsmm_bsmmdispatch_reducebatch_strd_unroll
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
ofw
+
handle
->
remainder_pixels
,
(
libxsmm_blasint
)
stride_a
,
(
libxsmm_blasint
)
stride_b
,
unroll_hint
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_flags
,
&
prefetch_mode
);
}
else
{
LDA
=
handle
->
ofmblock
;
LDB
=
handle
->
input_pixels
;
LDC
=
handle
->
ofmblock
;
prefetch_mode
=
libxsmm_get_gemm_prefetch
(
LIBXSMM_GEMM_PREFETCH_NONE
);
if
(
handle
->
use_hybrid_imgofm_parallelization
==
0
)
{
handle
->
upd_config_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
pixel_blocking
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_tc_flags
,
NULL
);
handle
->
upd_compute_kernel_gemm_linearized_pixels_no_hybrid_par
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
pixel_blocking
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_flags
,
&
prefetch_mode
);
}
else
{
if
(
handle
->
pack_to_cnhw
==
1
)
{
handle
->
upd_config_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
pixel_blocking
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_tc_flags
,
NULL
);
handle
->
upd_compute_kernel_gemm_linearized_pixels_hybrid_par_cnhw
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
pixel_blocking
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_flags
,
&
prefetch_mode
);
}
else
{
/* TODO: Hoist here hybrid parallelization logic and then we should be able to also provide unroll hint in the BRGEMM call */
stride_a
=
handle
->
blocksofm
*
handle
->
output_pixels
*
handle
->
ofmblock
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
stride_b
=
handle
->
blocksifm
*
handle
->
ifmblock
*
handle
->
input_pixels
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
handle
->
upd_config_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
pixel_blocking
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_tc_flags
,
NULL
);
handle
->
upd_compute_kernel_brgemm_linearized_pixels_hybrid_par_no_cnhw
=
libxsmm_bsmmdispatch_reducebatch_strd
(
handle
->
ofmblock
,
handle
->
ifmblock
,
handle
->
pixel_blocking
,
(
libxsmm_blasint
)
stride_a
,
(
libxsmm_blasint
)
stride_b
,
&
LDA
,
&
LDB
,
&
LDC
,
NULL
,
&
beta
,
&
l_flags
,
&
prefetch_mode
);
}
}
}
if
(
handle
->
desc
.
N
!=
handle
->
desc
.
threads
)
{
handle
->
use_intermediate_f32_wt_tensor
=
1
;
handle
->
use_hybrid_imgofm_parallelization
=
0
;
handle
->
weight_copies
=
LIBXSMM_MIN
(
handle
->
desc
.
N
,
handle
->
desc
.
threads
);
}
}
/**
 * Decide whether the weight-update pass has to copy the input into a
 * physically padded buffer.
 *
 * Returns 1 when the logical padding (desc.pad_h / desc.pad_w) differs from
 * the padding already present in the input buffer (desc.pad_h_in /
 * desc.pad_w_in) in BOTH dimensions, and 0 otherwise.
 *
 * NOTE(review): a mismatch in only one of the two dimensions yields 0 here;
 * confirm that descriptors with such a mixed padding are rejected upstream.
 */
LIBXSMM_API_INLINE int libxsmm_dnn_convolution_setup_upd_padding_copy( libxsmm_dnn_layer* handle ) {
  const int pad_h_mismatch = (handle->desc.pad_h != handle->desc.pad_h_in) ? 1 : 0;
  const int pad_w_mismatch = (handle->desc.pad_w != handle->desc.pad_w_in) ? 1 : 0;

  return ((0 != pad_h_mismatch) && (0 != pad_w_mismatch)) ? 1 : 0;
}
/**
 * Compute size and offset of every sub-buffer of the weight-update (UPD)
 * scratch and store them in the handle.
 *
 * The sub-buffers are laid out back to back in this order:
 *   1. packing/padding buffer     (upd_packing_padding_scratch_*)
 *   2. low-precision output copy  (upd_lp_output_full_scratch_*)
 *   3. low-precision input copy   (upd_lp_input_full_scratch_*)
 *   4. filter scratch             (upd_filter_scratch_*)
 *   5. low-precision filter copy  (upd_lp_filter_full_scratch_*)
 * Each size is rounded up to a multiple of LIBXSMM_CACHELINE before the
 * offsets are derived; upd_scratch_size receives the sum of all five sizes.
 *
 * All byte counts are accumulated in size_t: the leading operand of every
 * product is widened first, so that large N/pixel/feature-map combinations
 * cannot overflow a signed int before the result is stored.
 *
 * @param handle convolution handle; reads the descriptor and the UPD setup
 *               fields, writes the upd_*_scratch_size/offset fields.
 */
LIBXSMM_API_INLINE void libxsmm_dnn_convolution_setup_upd_scratch( libxsmm_dnn_layer* handle ) {
  handle->upd_packing_padding_scratch_size = 0;

  /* packing of input */
  if ( handle->upd_pack_input != 0 ) {
    /* evaluation order (left to right, integer division) is kept as is */
    handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
                                                 handle->desc.H/handle->desc.u *
                                                 handle->desc.W/handle->desc.v *
                                                 libxsmm_dnn_typesize(handle->datatype_in);
  }

  /* logical padding with copying on the fly (takes precedence over the packing size above) */
  if ( handle->upd_padding_copy != 0 ) {
    handle->upd_packing_padding_scratch_size = (size_t)handle->desc.N * handle->desc.C *
                                                 (handle->desc.H + 2 * handle->desc.pad_h) *
                                                 (handle->desc.W + 2 * handle->desc.pad_w) *
                                                 libxsmm_dnn_typesize(handle->datatype_in);
  }

  /* output/input buffer to transpose when we use bf16 */
  /* NOTE(review): the bf16 sizes below scale by sizeof(handle->datatype_in), i.e. the size of the
   * datatype_in field itself, not libxsmm_dnn_typesize() of its value; kept unchanged since users
   * of these buffers may rely on the resulting size - confirm before changing. */
  if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16 ) {
    if ( handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      /* this path uses the pixel counts/extended widths already stored in the handle */
      int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
      int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;

      if ( handle->upd_linearized_pixels == 1 ) {
        handle->upd_lp_output_full_scratch_size = (size_t)handle->desc.N * handle->output_pixels * handle->desc.K * sizeof(handle->datatype_in);
        handle->upd_lp_input_full_scratch_size  = (size_t)handle->desc.N * handle->input_pixels * handle->desc.C * sizeof(handle->datatype_in);
      }

      if ( handle->upd_linearized_pixels == 0 ) {
        handle->upd_lp_output_full_scratch_size = (size_t)handle->desc.N * OFHP * handle->ofwp_extended * handle->desc.K * sizeof(handle->datatype_in);
        handle->upd_lp_input_full_scratch_size  = (size_t)handle->desc.N * IFHP * handle->ifwp_extended * handle->desc.C * sizeof(handle->datatype_in);
      }
    } else {
      /* this path derives pixel counts locally; they are rounded up to a multiple of multiple_target */
      const int multiple_target = 2;
      int IFHP = (handle->upd_padding_copy == 1) ? handle->ifhp + 2 * handle->desc.pad_h : handle->ifhp;
      int IFWP = (handle->upd_padding_copy == 1) ? handle->ifwp + 2 * handle->desc.pad_w : handle->ifwp;
      int OFHP = (handle->upd_padding_copy == 1) ? handle->ofhp + 2 * handle->desc.pad_h : handle->ofhp;
      int OFWP = (handle->upd_padding_copy == 1) ? handle->ofwp + 2 * handle->desc.pad_w : handle->ofwp;

      if ( handle->upd_linearized_pixels == 1 ) {
        int compute_pixels = handle->ofw * handle->ofh + 2 * handle->desc.pad_w * (handle->ofh - 1);
        int remainder_pixels = (compute_pixels % multiple_target == 0) ? 0 : (compute_pixels / multiple_target + 1) * multiple_target - compute_pixels;
        int accum_length_pixels = compute_pixels + remainder_pixels;

        /* the input buffer is enlarged by however far offset + accumulation length exceeds IFWP*IFHP */
        int max_init_offset = 2 * handle->desc.pad_h * IFWP + 2 * handle->desc.pad_w;
        int max_compute_offset_input = max_init_offset + accum_length_pixels;
        int input_compute_pad = (max_compute_offset_input > IFWP * IFHP) ? max_compute_offset_input - IFWP * IFHP : 0;
        int input_pixels = IFWP * IFHP + input_compute_pad;

        if ( handle->upd_pack_input_upfront == 1 ) {
          input_pixels = accum_length_pixels;
        }

        handle->upd_lp_output_full_scratch_size = (size_t)handle->desc.N * accum_length_pixels * handle->desc.K * sizeof(handle->datatype_in);
        handle->upd_lp_input_full_scratch_size  = (size_t)handle->desc.N * input_pixels * handle->desc.C * sizeof(handle->datatype_in);
      }

      if ( handle->upd_linearized_pixels == 0 ) {
        int remainder_pixels = (handle->ofw % multiple_target == 0) ? 0 : (handle->ofw / multiple_target + 1) * multiple_target - handle->ofw;
        int ofwp_extended = OFWP + remainder_pixels;
        int ifwp_extended = IFWP + remainder_pixels;

        handle->upd_lp_output_full_scratch_size = (size_t)handle->desc.N * OFHP * ofwp_extended * handle->desc.K * sizeof(handle->datatype_in);
        handle->upd_lp_input_full_scratch_size  = (size_t)handle->desc.N * IFHP * ifwp_extended * handle->desc.C * sizeof(handle->datatype_in);
      }
    }

    /* one F32 filter copy per thread */
    handle->upd_lp_filter_full_scratch_size = (size_t)handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K * handle->desc.threads *
                                                libxsmm_dnn_typesize(LIBXSMM_DNN_DATATYPE_F32);
  } else {
    handle->upd_lp_output_full_scratch_size = 0;
    handle->upd_lp_input_full_scratch_size = 0;
    handle->upd_lp_filter_full_scratch_size = 0;
  }

  /* filter scratch: max(threads, N) float copies of the filter */
  handle->upd_filter_scratch_size = (size_t)handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K *
                                      LIBXSMM_MAX(handle->desc.threads, handle->desc.N) * sizeof(float);

  /* align sizes to full cacheline */
  handle->upd_packing_padding_scratch_size += ( handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
                                                LIBXSMM_CACHELINE - (handle->upd_packing_padding_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_lp_output_full_scratch_size += ( handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
                                               LIBXSMM_CACHELINE - (handle->upd_lp_output_full_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_lp_input_full_scratch_size += ( handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
                                              LIBXSMM_CACHELINE - (handle->upd_lp_input_full_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_filter_scratch_size += ( handle->upd_filter_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
                                       LIBXSMM_CACHELINE - (handle->upd_filter_scratch_size % LIBXSMM_CACHELINE);
  handle->upd_lp_filter_full_scratch_size += ( handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE == 0 ) ? 0 :
                                               LIBXSMM_CACHELINE - (handle->upd_lp_filter_full_scratch_size % LIBXSMM_CACHELINE);

  /* calculate offsets: each buffer starts where the previous one ends */
  handle->upd_packing_padding_scratch_offset = 0;
  handle->upd_lp_output_full_scratch_offset = handle->upd_packing_padding_scratch_size;
  handle->upd_lp_input_full_scratch_offset = handle->upd_lp_output_full_scratch_offset + handle->upd_lp_output_full_scratch_size;
  handle->upd_filter_scratch_offset = handle->upd_lp_input_full_scratch_offset + handle->upd_lp_input_full_scratch_size;
  handle->upd_lp_filter_full_scratch_offset = handle->upd_filter_scratch_offset + handle->upd_filter_scratch_size;

  /* set overall scratch size for update */
  handle->upd_scratch_size = handle->upd_packing_padding_scratch_size +
                               handle->upd_lp_output_full_scratch_size +
                               handle->upd_lp_input_full_scratch_size +
                               handle->upd_filter_scratch_size +
                               handle->upd_lp_filter_full_scratch_size;
}
LIBXSMM_API_INLINE
libxsmm_dnn_err_t
libxsmm_dnn_convolution_setup
(
libxsmm_dnn_layer
*
handle
)
{
libxsmm_dnn_err_t
status
=
LIBXSMM_DNN_SUCCESS
;
libxsmm_blasint
_ldi
=
64
,
_ldo
=
64
;
libxsmm_blasint
ldx
;
libxsmm_blasint
ldA
;
libxsmm_blasint
ldC
;
int
beta_int
;
float
beta
;
int
l_flags
;
int
l_tc_flags
;
/* init libxsmm */
LIBXSMM_INIT
/* Generic parameter setup */
handle
->
target_archid
=
libxsmm_target_archid
;
if
(
((
handle
->
target_archid
==
LIBXSMM_X86_AVX512_SPR
)
&&
(
handle
->
target_archid
<=
LIBXSMM_X86_ALLFEAT
))
&&
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
((
handle
->
desc
.
C
%
16
!=
0
)
||
(
handle
->
desc
.
K
%
16
!=
0
))
)
{
handle
->
target_archid
=
LIBXSMM_X86_AVX512_CPX
;
}
handle
->
ifmblock
=
libxsmm_dnn_convolution_setup_ifmblock
(
handle
);
handle
->
ofmblock
=
libxsmm_dnn_convolution_setup_ofmblock
(
handle
);
handle
->
fm_lp_block
=
libxsmm_dnn_convolution_setup_fm_lp_block
(
handle
);
handle
->
blocksifm
=
libxsmm_dnn_convolution_setup_blocksifm
(
handle
);
handle
->
blocksofm
=
libxsmm_dnn_convolution_setup_blocksofm
(
handle
);
/* If in SPR, generate tilerelease kernel */
if
(
handle
->
target_archid
>=
LIBXSMM_X86_AVX512_SPR
)
{
int
l_tr_flags
=
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG
|
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
);
handle
->
tilerelease_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ifmblock
,
handle
->
ifmblock
,
handle
->
ifmblock
,
NULL
,
NULL
,
NULL
,
NULL
,
NULL
,
&
l_tr_flags
,
NULL
);
}
/* FWD parameter setup */
handle
->
fwd_ofw_rb
=
libxsmm_dnn_convolution_setup_fwd_ofw_rb
(
handle
);
handle
->
pack_input
=
libxsmm_dnn_convolution_setup_pack_input_fwd
(
handle
);
handle
->
fwd_ofh_rb
=
libxsmm_dnn_convolution_setup_fwd_ofh_rb
(
handle
);
handle
->
fwd_gemm_pixels
=
libxsmm_dnn_convolution_setup_fwd_pixels_gemm
(
handle
);
handle
->
block_fwd_oj
=
libxsmm_dnn_convolution_setup_fwd_block_H
(
handle
);
handle
->
loop_order
=
libxsmm_dnn_convolution_setup_loop_order_fwd
(
handle
);
handle
->
blocksifm_blocking
=
libxsmm_dnn_convolution_setup_blocksifm_blocking
(
handle
);
handle
->
block_fwd_ofm
=
libxsmm_dnn_convolution_setup_block_fwd_OFM
(
handle
);
handle
->
block_fwd_ifm
=
libxsmm_dnn_convolution_setup_block_fwd_IFM
(
handle
);
handle
->
avoid_fmas_in_rim
=
libxsmm_dnn_convolution_setup_avoid_rim_fmas_fwd
(
handle
);
handle
->
use_ofm_parallelization
=
libxsmm_dnn_convolution_setup_use_ofm_parallelization
(
handle
);
handle
->
shuffle_filter_accesses
=
libxsmm_dnn_convolution_setup_shuffle_filter_accesses
(
handle
);
handle
->
avoid_acc_load
=
libxsmm_dnn_convolution_setup_avoid_acc_load
(
handle
);
handle
->
fwd_flags
=
libxsmm_dnn_convolution_setup_init_fwd_gemm_flags
(
handle
);
handle
->
use_fallback_fwd_loops
=
libxsmm_dnn_convolution_setup_fallback_loops_fwd
(
handle
);
handle
->
fwd_padding_copy
=
libxsmm_dnn_convolution_setup_fwd_padding_copy
(
handle
);
#if 0
if ( handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 ) {
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int brgemm_pf_oob = 0;
const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
handle->block_fwd_ofm = 1;
handle->block_fwd_oj = handle->fwd_ofh_rb;
ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
ldA = handle->ofmblock;
ldC = handle->ofmblock;
beta = (handle->avoid_acc_load) ? (float)0.0 : (float)1.0;
l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
if ( 0 == env_brgemm_pf_oob ) {
} else {
brgemm_pf_oob = atoi(env_brgemm_pf_oob);
}
if (brgemm_pf_oob > 0) {
prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
}
handle->fwd_compute_kernel_offs_f32 = NULL;
handle->fwd_compute_kernel_strd_f32 = NULL;
handle->fwd_compute_kernel_addr_a_f32 = NULL;
handle->fwd_compute_kernel_addr_b_f32 = NULL;
if (handle->desc.R == 1 && handle->desc.S == 1) {
const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
int stride_a = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock * libxsmm_dnn_typesize(handle->datatype_in);
int stride_b = IFW * IFH * handle->ifmblock * libxsmm_dnn_typesize(handle->datatype_in);
handle->fwd_compute_kernel_strd_f32 = libxsmm_smmdispatch_reducebatch_strd_unroll(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, stride_a, stride_b, handle->blocksifm_blocking, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
} else {
const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
int n_blocks = handle->desc.R * handle->desc.S * handle->blocksifm_blocking;
int i = 0, ifm, ki, kj;
handle->A_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
handle->B_offsets = (unsigned long long*) malloc(n_blocks * sizeof(unsigned long long));
for (ifm = 0; ifm < handle->blocksifm_blocking; ifm++) {
for (kj = 0; kj < handle->desc.R; kj++) {
for (ki = 0; ki < handle->desc.S; ki++) {
handle->A_offsets[i] = (ifm * handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock +
kj * handle->desc.S * handle->ifmblock * handle->ofmblock +
ki * handle->ifmblock * handle->ofmblock) * libxsmm_dnn_typesize(handle->datatype_in);
handle->B_offsets[i] = (ifm * IFH * IFW * handle->ifmblock +
kj * IFW * handle->ifmblock +
ki * handle->ifmblock) * libxsmm_dnn_typesize(handle->datatype_in);
i++;
}
}
}
handle->fwd_compute_kernel_offs_f32 = libxsmm_smmdispatch_reducebatch_offs(handle->ofmblock, handle->fwd_gemm_pixels, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
}
handle->fwd_compute_kernel_addr_a_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
handle->fwd_compute_kernel_addr_b_f32 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
}
#endif
if
(
((
handle
->
target_archid
==
LIBXSMM_X86_AVX512_SPR
)
&&
(
handle
->
target_archid
<=
LIBXSMM_X86_ALLFEAT
))
&&
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
handle
->
block_fwd_ofm
=
1
;
handle
->
block_fwd_oj
=
handle
->
fwd_ofh_rb
;
ldx
=
(
handle
->
pack_input
==
1
)
?
(
libxsmm_blasint
)
handle
->
ifmblock
:
(
libxsmm_blasint
)
handle
->
desc
.
v
*
handle
->
ifmblock
;
ldA
=
handle
->
ofmblock
;
ldC
=
handle
->
ofmblock
;
beta
=
(
handle
->
avoid_acc_load
)
?
(
float
)
0
.
0
:
(
float
)
1
.
0
;
l_flags
=
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
)
|
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG
|
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG
;
l_tc_flags
=
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG
|
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
);
handle
->
fwd_compute_kernel_addr
=
NULL
;
handle
->
fwd_compute_kernel_offs_a
=
NULL
;
handle
->
fwd_compute_kernel_offs_b
=
NULL
;
handle
->
fwd_compute_kernel_strd
=
NULL
;
if
(
handle
->
desc
.
R
==
1
&&
handle
->
desc
.
S
==
1
)
{
const
int
IFW
=
(
handle
->
pack_input
==
1
)
?
handle
->
ofwp
:
handle
->
ifwp
;
const
int
IFH
=
(
handle
->
pack_input
==
1
)
?
handle
->
ofhp
:
handle
->
ifhp
;
size_t
stride_a
=
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
size_t
stride_b
=
IFW
*
IFH
*
handle
->
ifmblock
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
handle
->
fwd_compute_kernel_strd
=
libxsmm_bmmdispatch_reducebatch_strd_unroll
(
handle
->
ofmblock
,
handle
->
fwd_gemm_pixels
,
handle
->
ifmblock
,
(
libxsmm_blasint
)
stride_a
,
(
libxsmm_blasint
)
stride_b
,
handle
->
blocksifm_blocking
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
NULL
);
}
else
{
const
int
IFW
=
(
handle
->
fwd_padding_copy
==
1
)
?
handle
->
ifwp
+
2
*
handle
->
desc
.
pad_w
:
(
(
handle
->
pack_input
==
1
)
?
handle
->
ofwp
:
handle
->
ifwp
);
const
int
IFH
=
(
handle
->
fwd_padding_copy
==
1
)
?
handle
->
ifhp
+
2
*
handle
->
desc
.
pad_h
:
(
(
handle
->
pack_input
==
1
)
?
handle
->
ofhp
:
handle
->
ifhp
);
int
n_blocks
=
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
blocksifm_blocking
;
int
i
=
0
,
ifm
,
ki
,
kj
;
handle
->
A_offsets
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
handle
->
B_offsets
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
for
(
ifm
=
0
;
ifm
<
handle
->
blocksifm_blocking
;
ifm
++
)
{
for
(
kj
=
0
;
kj
<
handle
->
desc
.
R
;
kj
++
)
{
for
(
ki
=
0
;
ki
<
handle
->
desc
.
S
;
ki
++
)
{
handle
->
A_offsets
[
i
]
=
(
ifm
*
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
kj
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
ki
*
handle
->
ifmblock
*
handle
->
ofmblock
)
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
handle
->
B_offsets
[
i
]
=
(
ifm
*
IFH
*
IFW
*
handle
->
ifmblock
+
kj
*
IFW
*
handle
->
ifmblock
+
ki
*
handle
->
ifmblock
)
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
i
++
;
}
}
}
handle
->
fwd_compute_kernel_offs_a
=
libxsmm_bmmdispatch_reducebatch_offs
(
handle
->
ofmblock
,
handle
->
fwd_gemm_pixels
,
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
NULL
);
handle
->
fwd_compute_kernel_offs_b
=
libxsmm_bsmmdispatch_reducebatch_offs
(
handle
->
ofmblock
,
handle
->
fwd_gemm_pixels
,
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
NULL
);
}
handle
->
fwd_config_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ofmblock
,
handle
->
fwd_gemm_pixels
,
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_tc_flags
,
NULL
);
}
handle
->
code_fwd
[
0
].
ptr
=
0
;
handle
->
code_fwd
[
1
].
ptr
=
0
;
handle
->
code_fwd
[
2
].
ptr
=
0
;
/* JIT cvt eltwise functions for fwd convolutions */
if
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
{
_ldi
=
handle
->
ofmblock
*
handle
->
ofwp
;
_ldo
=
handle
->
ofmblock
*
handle
->
ofwp
;
handle
->
fwd_cvtfp32bf16_kernel
=
libxsmm_dispatch_meltw_unary
(
handle
->
ofmblock
*
handle
->
fwd_ofw_rb
,
handle
->
fwd_ofh_rb
,
&
_ldi
,
&
_ldo
,
LIBXSMM_DATATYPE_F32
,
LIBXSMM_DATATYPE_F32
,
LIBXSMM_DATATYPE_BF16
,
LIBXSMM_MELTW_FLAG_UNARY_NONE
,
LIBXSMM_MELTW_TYPE_UNARY_IDENTITY
);
}
/* Create strided BRGEMMs for i8i32 convolutions */
if
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I8
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_I32
))
{
ldx
=
(
handle
->
pack_input
==
1
)
?
(
libxsmm_blasint
)
handle
->
ifmblock
:
(
libxsmm_blasint
)
handle
->
desc
.
v
*
handle
->
ifmblock
;
ldA
=
handle
->
ofmblock
;
ldC
=
handle
->
ofmblock
;
beta_int
=
(
handle
->
avoid_acc_load
)
?
0
:
1
;
l_flags
=
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
)
|
handle
->
fwd_flags
;
if
(
handle
->
desc
.
R
==
1
&&
handle
->
desc
.
S
==
1
)
{
const
int
IFW
=
(
handle
->
pack_input
==
1
)
?
handle
->
ofwp
:
handle
->
ifwp
;
const
int
IFH
=
(
handle
->
pack_input
==
1
)
?
handle
->
ofhp
:
handle
->
ifhp
;
libxsmm_blasint
stride_A
=
handle
->
ifmblock
*
handle
->
ofmblock
*
sizeof
(
char
);
libxsmm_blasint
stride_B
=
handle
->
ifmblock
*
IFW
*
IFH
*
sizeof
(
char
)
;
handle
->
gemm_fwd
.
xgemm
.
subimrs
=
libxsmm_subimmdispatch_reducebatch_strd
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
handle
->
fwd_ofw_rb
,
handle
->
ifmblock
,
stride_A
,
stride_B
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta_int
,
&
l_flags
,
NULL
);
}
else
{
const
int
IFW
=
(
handle
->
pack_input
==
1
)
?
handle
->
ofwp
:
handle
->
ifwp
;
const
int
IFH
=
(
handle
->
pack_input
==
1
)
?
handle
->
ofhp
:
handle
->
ifhp
;
if
(
handle
->
avoid_fmas_in_rim
==
0
)
{
int
n_blocks
=
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
blocksifm_blocking
;
int
i
=
0
,
ifm
,
ki
,
kj
;
handle
->
A_offsets
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
handle
->
B_offsets
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
for
(
ifm
=
0
;
ifm
<
handle
->
blocksifm_blocking
;
ifm
++
)
{
for
(
kj
=
0
;
kj
<
handle
->
desc
.
R
;
kj
++
)
{
for
(
ki
=
0
;
ki
<
handle
->
desc
.
S
;
ki
++
)
{
handle
->
A_offsets
[
i
]
=
(
ifm
*
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
kj
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
ki
*
handle
->
ifmblock
*
handle
->
ofmblock
)
*
sizeof
(
char
);
handle
->
B_offsets
[
i
]
=
(
ifm
*
IFH
*
IFW
*
handle
->
ifmblock
+
kj
*
IFW
*
handle
->
ifmblock
+
ki
*
handle
->
ifmblock
)
*
sizeof
(
char
);
i
++
;
}
}
}
handle
->
gemm_fwd
.
xgemm
.
subimro
=
libxsmm_subimmdispatch_reducebatch_offs
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
handle
->
fwd_ofw_rb
,
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta_int
,
&
l_flags
,
NULL
);
}
else
{
libxsmm_blasint
stride_A
=
handle
->
ifmblock
*
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ofmblock
*
sizeof
(
char
);
libxsmm_blasint
stride_B
=
handle
->
ifmblock
*
IFW
*
IFH
*
sizeof
(
char
)
;
handle
->
gemm_fwd
.
xgemm
.
subimrs
=
libxsmm_subimmdispatch_reducebatch_strd
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
handle
->
fwd_ofw_rb
,
handle
->
ifmblock
,
stride_A
,
stride_B
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta_int
,
&
l_flags
,
NULL
);
handle
->
gemm_fwd2
.
xgemm
.
subimrs
=
libxsmm_subimmdispatch_reducebatch_strd
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
(
handle
->
fwd_ofw_rb
-
1
),
handle
->
ifmblock
,
stride_A
,
stride_B
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta_int
,
&
l_flags
,
NULL
);
}
}
}
else
if
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I8
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_I8
))
{
ldx
=
(
libxsmm_blasint
)
handle
->
desc
.
v
*
handle
->
ifmblock
;
ldA
=
handle
->
ofmblock
;
ldC
=
handle
->
ofmblock
;
beta_int
=
0
;
l_flags
=
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
)
|
handle
->
fwd_flags
;
if
(
handle
->
desc
.
R
==
1
&&
handle
->
desc
.
S
==
1
)
{
const
int
IFW
=
handle
->
ifwp
;
const
int
IFH
=
handle
->
ifhp
;
libxsmm_blasint
stride_A
=
handle
->
ifmblock
*
handle
->
ofmblock
*
sizeof
(
char
);
libxsmm_blasint
stride_B
=
handle
->
ifmblock
*
IFW
*
IFH
*
sizeof
(
char
)
;
handle
->
gemm_fwd
.
xgemm
.
sububmrs
=
libxsmm_sububmmdispatch_reducebatch_strd
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
handle
->
fwd_ofw_rb
,
handle
->
ifmblock
,
stride_A
,
stride_B
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta_int
,
&
l_flags
,
NULL
);
}
else
{
const
int
IFW
=
handle
->
ifwp
;
const
int
IFH
=
handle
->
ifhp
;
int
n_blocks
=
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
blocksifm_blocking
;
int
i
=
0
,
ifm
,
ki
,
kj
;
handle
->
A_offsets
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
handle
->
B_offsets
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
for
(
ifm
=
0
;
ifm
<
handle
->
blocksifm_blocking
;
ifm
++
)
{
for
(
kj
=
0
;
kj
<
handle
->
desc
.
R
;
kj
++
)
{
for
(
ki
=
0
;
ki
<
handle
->
desc
.
S
;
ki
++
)
{
handle
->
A_offsets
[
i
]
=
(
ifm
*
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
kj
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
ki
*
handle
->
ifmblock
*
handle
->
ofmblock
)
*
sizeof
(
char
);
handle
->
B_offsets
[
i
]
=
(
ifm
*
IFH
*
IFW
*
handle
->
ifmblock
+
kj
*
IFW
*
handle
->
ifmblock
+
ki
*
handle
->
ifmblock
)
*
sizeof
(
char
);
i
++
;
}
}
}
handle
->
gemm_fwd
.
xgemm
.
sububmro
=
libxsmm_sububmmdispatch_reducebatch_offs
(
handle
->
ofmblock
,
handle
->
fwd_ofh_rb
*
handle
->
fwd_ofw_rb
,
handle
->
ifmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta_int
,
&
l_flags
,
NULL
);
}
}
#if 0
/* Spit out FWD parameters that are selected... */
printf("FWD params...\n");
printf("Fwd_ofw_rb = %d\n", handle->fwd_ofw_rb);
printf("Fwd_ofh_rb = %d\n", handle->fwd_ofh_rb);
printf("Pack input = %d\n", handle->pack_input);
printf("Block oj = %d\n", handle->block_fwd_oj);
printf("Loop order = %d\n", handle->loop_order);
printf("Blocksifm_blocking = %d\n", handle->blocksifm_blocking);
printf("Block fwd ofm = %d\n", handle->block_fwd_ofm);
printf("Block fwd ifm = %d\n", handle->block_fwd_ifm);
printf("Avoid rim fmas = %d\n", handle->avoid_fmas_in_rim);
printf("Ofm parallelization = %d\n", handle->use_ofm_parallelization);
printf("Shuffle filter accesses = %d\n", handle->shuffle_filter_accesses);
printf("Avoid acc load = %d\n", handle->avoid_acc_load);
printf("Fwd GEMM flags = %d\n", handle->fwd_flags);
#endif
/* BWD parameter setup */
handle
->
bwd_ofw_rb
=
libxsmm_dnn_convolution_setup_bwd_ofw_rb
(
handle
);
handle
->
bwd_ofh_rb
=
libxsmm_dnn_convolution_setup_bwd_ofh_rb
(
handle
);
handle
->
bwd_gemm_pixels
=
libxsmm_dnn_convolution_setup_bwd_pixels_gemm
(
handle
);
handle
->
pack_input_bwd
=
libxsmm_dnn_convolution_setup_pack_input_bwd
(
handle
);
handle
->
spread_input_bwd
=
libxsmm_dnn_convolution_setup_spread_input_bwd
(
handle
);
handle
->
blocksofm_blocking
=
libxsmm_dnn_convolution_setup_blocksofm_blocking
(
handle
);
handle
->
avoid_acc_load_bwd
=
libxsmm_dnn_convolution_setup_avoid_acc_load_bwd
(
handle
);
handle
->
use_ifm_parallelization
=
libxsmm_dnn_convolution_setup_use_ifm_parallelization
(
handle
);
handle
->
block_bwd_ofm
=
libxsmm_dnn_convolution_setup_block_bwd_OFM
(
handle
);
handle
->
block_bwd_ifm
=
libxsmm_dnn_convolution_setup_block_bwd_IFM
(
handle
);
handle
->
block_bwd_oj
=
libxsmm_dnn_convolution_setup_bwd_block_H
(
handle
);
handle
->
use_fallback_bwd_loops
=
libxsmm_dnn_convolution_setup_fallback_loops_bwd
(
handle
);
handle
->
bwd_flags
=
libxsmm_dnn_convolution_setup_init_bwd_gemm_flags
(
handle
);
if
(
((
handle
->
target_archid
==
LIBXSMM_X86_AVX512_SPR
)
&&
(
handle
->
target_archid
<=
LIBXSMM_X86_ALLFEAT
))
&&
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
handle
->
block_bwd_ifm
=
1
;
handle
->
block_bwd_oj
=
handle
->
bwd_ofh_rb
;
ldx
=
((
libxsmm_blasint
)
handle
->
ofmblock
);
ldA
=
handle
->
ifmblock
;
ldC
=
(
handle
->
spread_input_bwd
==
1
)
?
handle
->
ifmblock
*
handle
->
desc
.
v
:
handle
->
ifmblock
;
beta
=
(
handle
->
avoid_acc_load_bwd
)
?
(
float
)
0
.
0
:
(
float
)
1
.
0
;
l_flags
=
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
)
|
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG
|
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG
;
l_tc_flags
=
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG
|
(
LIBXSMM_GEMM_VNNI_FLAGS
(
'N'
,
'N'
,
'V'
,
'N'
)
);
handle
->
bwd_compute_kernel_addr
=
NULL
;
handle
->
bwd_compute_kernel_offs
=
NULL
;
handle
->
bwd_compute_kernel_strd
=
NULL
;
if
(
handle
->
desc
.
R
==
1
&&
handle
->
desc
.
S
==
1
)
{
size_t
stride_a
=
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
size_t
stride_b
=
handle
->
ofwp
*
handle
->
ofhp
*
handle
->
ofmblock
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
handle
->
bwd_compute_kernel_strd
=
libxsmm_bsmmdispatch_reducebatch_strd_unroll
(
handle
->
ifmblock
,
handle
->
bwd_gemm_pixels
,
handle
->
ofmblock
,
(
libxsmm_blasint
)
stride_a
,
(
libxsmm_blasint
)
stride_b
,
handle
->
blocksofm_blocking
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
NULL
);
}
else
{
int
n_blocks
=
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
blocksofm_blocking
;
int
i
=
0
,
ofm
,
ki
,
kj
;
handle
->
A_offsets_bwd
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
handle
->
B_offsets_bwd
=
(
unsigned
long
long
*
)
malloc
(
n_blocks
*
sizeof
(
unsigned
long
long
));
for
(
ofm
=
0
;
ofm
<
handle
->
blocksofm_blocking
;
ofm
++
)
{
for
(
kj
=
0
;
kj
<
handle
->
desc
.
R
;
kj
++
)
{
for
(
ki
=
0
;
ki
<
handle
->
desc
.
S
;
ki
++
)
{
handle
->
A_offsets_bwd
[
i
]
=
(
ofm
*
handle
->
desc
.
R
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
kj
*
handle
->
desc
.
S
*
handle
->
ifmblock
*
handle
->
ofmblock
+
ki
*
handle
->
ifmblock
*
handle
->
ofmblock
)
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
handle
->
B_offsets_bwd
[
i
]
=
(
ofm
*
handle
->
ofhp
*
handle
->
ofwp
*
handle
->
ofmblock
+
kj
*
handle
->
ofwp
*
handle
->
ofmblock
+
ki
*
handle
->
ofmblock
)
*
libxsmm_dnn_typesize
(
handle
->
datatype_in
);
i
++
;
}
}
}
handle
->
bwd_compute_kernel_offs
=
libxsmm_bsmmdispatch_reducebatch_offs
(
handle
->
ifmblock
,
handle
->
bwd_gemm_pixels
,
handle
->
ofmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_flags
,
NULL
);
}
handle
->
bwd_config_kernel
=
libxsmm_bsmmdispatch
(
handle
->
ifmblock
,
handle
->
bwd_gemm_pixels
,
handle
->
ofmblock
,
&
ldA
,
&
ldx
,
&
ldC
,
NULL
,
&
beta
,
&
l_tc_flags
,
NULL
);
}
#if 0
/* Spit out BWD parameters that are selected... */
printf("BWD params...\n");
printf("Bwd_ofw_rb = %d\n", handle->bwd_ofw_rb);
printf("Bwd_ofh_rb = %d\n", handle->bwd_ofh_rb);
printf("Pack input = %d\n", handle->pack_input_bwd);
printf("Spread input = %d\n", handle->spread_input_bwd);
printf("Blocksofm_blocking = %d\n", handle->blocksofm_blocking);
printf("Avoid acc load = %d\n", handle->avoid_acc_load_bwd);
printf("Ifm parallelization = %d\n", handle->use_ifm_parallelization);
printf("Block bwd ofm = %d\n", handle->block_bwd_ofm);
printf("Block bwd ifm = %d\n", handle->block_bwd_ifm);
printf("Block oj = %d\n", handle->block_bwd_oj);
#endif
handle
->
code_bwd
[
0
].
ptr
=
0
;
handle
->
code_bwd
[
1
].
ptr
=
0
;
handle
->
code_bwd
[
2
].
ptr
=
0
;
/* Transpose kernel used for filter transpose in bwd pass */
handle
->
tr_kernel
=
libxsmm_dispatch_meltw_unary
(
64
,
16
,
&
(
_ldi
),
&
(
_ldo
),
LIBXSMM_DATATYPE_F32
,
LIBXSMM_DATATYPE_F32
,
LIBXSMM_DATATYPE_F32
,
LIBXSMM_MELTW_FLAG_UNARY_NONE
,
LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT
);
/* UPD parameter setup */
handle
->
upd_linearized_tasklist
=
libxsmm_dnn_convolution_setup_linearized_tasklist_upd
(
handle
);
handle
->
upd_avoid_rim_fmas
=
libxsmm_dnn_convolution_setup_avoid_rim_fmas_upd
(
handle
);
handle
->
upd_pack_input
=
libxsmm_dnn_convolution_setup_pack_input_upd
(
handle
);
handle
->
upd_use_batchreduce
=
libxsmm_dnn_convolution_setup_use_batchreduce_upd
(
handle
);
handle
->
upd_ofw_rb
=
libxsmm_dnn_convolution_setup_upd_ofw_rb
(
handle
);
handle
->
upd_ofh_rb
=
libxsmm_dnn_convolution_setup_upd_ofh_rb
(
handle
);
handle
->
upd_loop_order
=
libxsmm_dnn_convolution_setup_loop_order_upd
(
handle
);
handle
->
weight_copies
=
libxsmm_dnn_convolution_setup_weight_copies_upd
(
handle
);
handle
->
block_upd_ofm
=
libxsmm_dnn_convolution_setup_block_upd_OFM
(
handle
);
handle
->
block_upd_ifm
=
libxsmm_dnn_convolution_setup_block_upd_IFM
(
handle
);
handle
->
upd_loop_order
=
libxsmm_dnn_convolution_setup_loop_order_upd
(
handle
);
handle
->
upd_padding_copy
=
libxsmm_dnn_convolution_setup_upd_padding_copy
(
handle
);
if
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
{
if
((
handle
->
target_archid
==
LIBXSMM_X86_AVX512_SPR
)
&&
(
handle
->
target_archid
<=
LIBXSMM_X86_ALLFEAT
))
{
libxsmm_dnn_convolution_setup_bf16_upd_amx
(
handle
);
}
else
{
libxsmm_dnn_convolution_setup_bf16_upd
(
handle
);
}
}
#if 0
/* Spit out UPD parameters that are selected... */
printf("UPD params...\n");
if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_BF16) {
printf("BF16 path...\n");
printf("UPD use_hybrid_imgofm_parallelization = %d\n", handle->use_hybrid_imgofm_parallelization);
printf("UPD linearized_pixels = %d\n", handle->upd_linearized_pixels);
printf("UPD upd_trans_w_only = %d\n", handle->upd_trans_w_only);
printf("UPD on_the_fly_input_packing = %d\n", handle->on_the_fly_input_packing);
printf("UPD use_intermediate_f32_wt_tensor = %d\n", handle->use_intermediate_f32_wt_tensor);
printf("UPD pack to CNHW format = %d\n", handle->pack_to_cnhw);
printf("UPD batchreduce H pixels = %d\n", handle->batchreduce_h_pixels);
}
printf("UPD linearized tasks = %d\n", handle->upd_linearized_tasklist);
printf("UPD avoid rim fmas = %d\n", handle->upd_avoid_rim_fmas);
printf("UPD Pack input = %d\n", handle->upd_pack_input);
printf("UPD use batch-reduce GEMM = %d\n", handle->upd_use_batchreduce);
printf("Upd_ofw_rb = %d\n", handle->upd_ofw_rb);
printf("Upd_ofh_rb = %d\n", handle->upd_ofh_rb);
printf("UPD loop order = %d\n", handle->upd_loop_order);
printf("UPD weight_copies = %d\n", handle->weight_copies);
printf("Block upd ofm = %d\n", handle->block_upd_ofm);
printf("Block upd ifm = %d\n", handle->block_upd_ifm);
#endif
handle
->
code_upd
[
0
].
ptr
=
0
;
handle
->
code_upd
[
1
].
ptr
=
0
;
/* prepare barrier */
handle
->
barrier
=
libxsmm_barrier_create
(
handle
->
desc
.
threads
,
1
);
/* set up scratch */
libxsmm_dnn_convolution_setup_fwd_scratch
(
handle
);
libxsmm_dnn_convolution_setup_bwd_scratch
(
handle
);
libxsmm_dnn_convolution_setup_upd_scratch
(
handle
);
handle
->
scratch
=
0
;
handle
->
scratch_size
=
LIBXSMM_MAX
(
handle
->
fwd_scratch_size
,
LIBXSMM_MAX
(
handle
->
bwd_scratch_size
,
handle
->
upd_scratch_size
)
);
return
status
;
}
#undef MIXED
#undef KHWC
#undef HWKC
#undef CHWK
#undef HWCK
/**
 * Creates a convolution layer handle from a convolution descriptor.
 *
 * The descriptor is validated (tensor formats, padding mode, algorithm),
 * a zero-initialized handle is allocated, the descriptor is copied into it,
 * padded input/output extents are derived, blocking factors are preset to 1,
 * and finally libxsmm_dnn_convolution_setup() is run on the handle.
 *
 * @param conv_desc descriptor of the convolution (passed by value and copied
 *                  into the handle).
 * @param status    out-parameter receiving LIBXSMM_DNN_SUCCESS or an error
 *                  code; must not be NULL (it is written unconditionally).
 * @return the new handle, or NULL whenever *status != LIBXSMM_DNN_SUCCESS.
 */
LIBXSMM_API libxsmm_dnn_layer* libxsmm_dnn_create_conv_layer(
  libxsmm_dnn_conv_desc conv_desc,
  libxsmm_dnn_err_t*    status)
{
  libxsmm_dnn_layer* handle = 0;

  *status = LIBXSMM_DNN_SUCCESS;

  /* currently we don't support NCHW */
  if ( (conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCHW) > 0 ) {
    *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW;
    return 0;
  }

  /* currently we don't support KCRS; the flag is rejected for the filter
   * format as well as for the buffer format (only the latter was checked
   * before, which let a KCRS filter format slip through handle creation) */
  if ( ((conv_desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_KCRS) > 0) ||
       ((conv_desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_KCRS) > 0) )
  {
    *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS;
    return 0;
  }

  /* we only support physical padding these days */
  /* @TODO: add logical padding support for other datatypes than FP32 */
  if ( ( (conv_desc.pad_h != conv_desc.pad_h_in)  ||
         (conv_desc.pad_w != conv_desc.pad_w_in)  ||
         (conv_desc.pad_h != conv_desc.pad_h_out) ||
         (conv_desc.pad_w != conv_desc.pad_w_out) ) &&
       (conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) &&
       (conv_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) )
  {
    *status = LIBXSMM_DNN_ERR_INVALID_PADDING;
    return 0;
  }

  /* zero entire content; not only safer but also sets data and code pointers to NULL */
  handle = (libxsmm_dnn_layer*)calloc(1, sizeof(libxsmm_dnn_layer));
  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    return 0;
  }

  /* initialize known handle components */
  handle->desc = conv_desc;
  handle->datatype_in = conv_desc.datatype_in;
  handle->datatype_out = conv_desc.datatype_out;
  /* NOTE(review): the combination of datatype_in/datatype_out is not
   * validated here; presumably unsupported pairs are rejected later
   * (setup or tensor-layout creation) -- confirm. */
  handle->buffer_format = conv_desc.buffer_format;
  handle->filter_format = conv_desc.filter_format;
  handle->fuse_ops = conv_desc.fuse_ops;
  handle->options = conv_desc.options;

  /* derive additional values: padded input extents ... */
  handle->ifhp = conv_desc.H + 2 * conv_desc.pad_h_in;
  handle->ifwp = conv_desc.W + 2 * conv_desc.pad_w_in;
  /* ... output extents from the strided convolution (u and v are the
   * divisors, i.e. they must be non-zero) ... */
  handle->ofh = (conv_desc.H + 2 * conv_desc.pad_h - conv_desc.R) / conv_desc.u + 1;
  handle->ofw = (conv_desc.W + 2 * conv_desc.pad_w - conv_desc.S) / conv_desc.v + 1;
  /* ... and padded output extents */
  handle->ofhp = handle->ofh + 2 * conv_desc.pad_h_out;
  handle->ofwp = handle->ofw + 2 * conv_desc.pad_w_out;

  /* preset blocking parameters before the setup routine runs */
  handle->ifmblock = 1;
  handle->ofmblock = 1;
  handle->blocksifm = conv_desc.C;
  handle->blocksofm = conv_desc.K;
  handle->fwd_ofw_rb = 1;
  handle->fwd_ofh_rb = 1;
  handle->bwd_ofw_rb = 1;
  handle->bwd_ofh_rb = 1;
  handle->upd_ofw_rb = 1;
  handle->upd_ofh_rb = 1;
  handle->fm_lp_block = 1;
  handle->blocksifm_blocking = 1;
  handle->blocksofm_blocking = 1;

  /* Set algorithm to use: AUTO resolves to the direct algorithm */
  if (conv_desc.algo == LIBXSMM_DNN_CONV_ALGO_AUTO) {
    handle->algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
  }
  else {
    handle->algo = conv_desc.algo;
  }

  /* only the direct algorithm is accepted */
  if (handle->algo != LIBXSMM_DNN_CONV_ALGO_DIRECT) {
    *status = LIBXSMM_DNN_ERR_INVALID_ALGO;
    free(handle);
    return 0;
  }

  *status = libxsmm_dnn_convolution_setup(handle);

  /* account for eventually deallocated handle: on failure only the pointer is
   * dropped here, it is not freed.
   * NOTE(review): presumably the setup routine releases the handle on its
   * error paths; if it does not, this leaks the handle -- confirm. */
  if (LIBXSMM_DNN_SUCCESS != *status) {
    handle = 0;
  }

  return handle;
}
/**
 * Destroys a convolution layer handle.
 *
 * Releases the barrier attached to the handle (if any), the backward-pass
 * offset tables, and finally the handle structure itself.
 *
 * @param handle handle to destroy; NULL is accepted and ignored.
 * @return always LIBXSMM_DNN_SUCCESS.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_conv_layer(const libxsmm_dnn_layer* handle)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    /* Deallocate barrier */
    if (handle->barrier != 0) {
      libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier);
    }

    /* Deallocate the offset tables which the setup routine allocates with
     * malloc for the backward batch-reduce kernel. The handle is created by
     * calloc, hence tables that were never allocated are NULL and free()
     * ignores them.
     * NOTE(review): other heap buffers attached by parts of the setup code
     * may need the same treatment -- confirm. */
    free((void*)handle->A_offsets_bwd);
    free((void*)handle->B_offsets_bwd);

    /* deallocate handle structure itself */
    free(/*remove constness*/(libxsmm_dnn_layer*)handle);
  }

  return status;
}
LIBXSMM_API
libxsmm_dnn_tensor_datalayout
*
libxsmm_dnn_create_tensor_datalayout
(
const
libxsmm_dnn_layer
*
handle
,
const
libxsmm_dnn_tensor_type
type
,
libxsmm_dnn_err_t
*
status
)
{
libxsmm_dnn_tensor_datalayout
*
layout
;
*
status
=
LIBXSMM_DNN_SUCCESS
;
layout
=
0
;
if
(
handle
!=
0
)
{
/* zero entire content; not only safer but also sets data and code pointers to NULL */
layout
=
(
libxsmm_dnn_tensor_datalayout
*
)
calloc
(
1
,
sizeof
(
libxsmm_dnn_tensor_datalayout
));
if
(
layout
!=
0
)
{
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
||
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_ACTIVATION
;
if
((
handle
->
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
)
>
0
)
{
if
(
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
/* @TODO this need to change */
}
else
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I16
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_I32
)
)
{
if
(
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
)
{
layout
->
datatype
=
handle
->
datatype_in
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
datatype
=
handle
->
datatype_out
;
}
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
{
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_BF16
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
6
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
6
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
{
/* coverity[dead_error_begin] */
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
if
(
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I16
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I8
)
)
{
if
(
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
)
)
{
layout
->
datatype
=
handle
->
datatype_in
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
)
{
layout
->
datatype
=
handle
->
datatype_out
;
}
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
5
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
5
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
{
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
{
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
if
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
{
layout
->
num_dims
=
5
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
N
;
}
else
{
/* coverity[dead_error_begin] */
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
if
((
handle
->
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_NHWC
)
>
0
)
{
if
(
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
4
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
4
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
4
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_W
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_H
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
if
(
(
type
==
LIBXSMM_DNN_REGULAR_INPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_INPUT
)
||
(
type
==
LIBXSMM_DNN_INPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
*
handle
->
blocksifm
;
layout
->
dim_size
[
1
]
=
handle
->
ifwp
;
layout
->
dim_size
[
2
]
=
handle
->
ifhp
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
N
;
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_OUTPUT
)
||
(
type
==
LIBXSMM_DNN_OUTPUT
)
)
{
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
*
handle
->
blocksofm
;
layout
->
dim_size
[
1
]
=
handle
->
ofwp
;
layout
->
dim_size
[
2
]
=
handle
->
ofhp
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
N
;
}
else
{
/* coverity[dead_error_begin] */
free
(
layout
->
dim_type
);
free
(
layout
->
dim_size
);
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_FILTER
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_FILTER
)
||
(
type
==
LIBXSMM_DNN_FILTER
)
)
{
layout
->
format
=
handle
->
filter_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_FILTER
;
if
((
handle
->
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
)
>
0
)
{
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
6
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
6
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
6
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_S
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_R
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
2
]
=
handle
->
desc
.
S
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
R
;
layout
->
dim_size
[
4
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
5
]
=
handle
->
blocksofm
;
}
}
else
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_BF16
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
7
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
7
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
7
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_S
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_R
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
6
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
handle
->
fm_lp_block
;
layout
->
dim_size
[
1
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
2
]
=
handle
->
ifmblock
/
handle
->
fm_lp_block
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
S
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
R
;
layout
->
dim_size
[
5
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
6
]
=
handle
->
blocksofm
;
}
}
else
if
(
((
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I16
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
))
||
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_I8
)
)
{
if
(
(
type
==
LIBXSMM_DNN_REGULAR_FILTER
)
||
(
type
==
LIBXSMM_DNN_FILTER
)
)
{
layout
->
datatype
=
handle
->
datatype_in
;
}
else
if
(
type
==
LIBXSMM_DNN_GRADIENT_FILTER
)
{
layout
->
datatype
=
handle
->
datatype_out
;
}
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
7
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
7
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
if
((
type
==
LIBXSMM_DNN_REGULAR_FILTER
)
||
(
type
==
LIBXSMM_DNN_FILTER
))
{
layout
->
num_dims
=
7
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_S
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_R
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
6
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_size
[
0
]
=
handle
->
fm_lp_block
;
layout
->
dim_size
[
1
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
2
]
=
handle
->
ifmblock
/
handle
->
fm_lp_block
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
S
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
R
;
layout
->
dim_size
[
5
]
=
handle
->
blocksifm
;
layout
->
dim_size
[
6
]
=
handle
->
blocksofm
;
}
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
if
((
handle
->
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_RSCK
)
>
0
)
{
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
4
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
4
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
4
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_S
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_R
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
*
handle
->
blocksofm
;
layout
->
dim_size
[
1
]
=
handle
->
ifmblock
*
handle
->
blocksifm
;
layout
->
dim_size
[
2
]
=
handle
->
desc
.
S
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
R
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
type
==
LIBXSMM_DNN_REGULAR_FILTER_TRANS
)
{
layout
->
format
=
handle
->
filter_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_REGULAR_FILTER_TRANS
;
if
((
handle
->
filter_format
&
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
)
>
0
)
{
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
6
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
6
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
6
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_S
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_R
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
2
]
=
handle
->
desc
.
S
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
R
;
layout
->
dim_size
[
4
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
5
]
=
handle
->
blocksifm
;
}
}
else
if
(
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_BF16
)
&&
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_BF16
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
7
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
7
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
7
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_S
;
layout
->
dim_type
[
4
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_R
;
layout
->
dim_type
[
5
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_K
;
layout
->
dim_type
[
6
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_size
[
0
]
=
handle
->
fm_lp_block
;
layout
->
dim_size
[
1
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
2
]
=
handle
->
ofmblock
/
handle
->
fm_lp_block
;
layout
->
dim_size
[
3
]
=
handle
->
desc
.
S
;
layout
->
dim_size
[
4
]
=
handle
->
desc
.
R
;
layout
->
dim_size
[
5
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
6
]
=
handle
->
blocksifm
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
#if 0
} else if ((handle->filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 4;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
layout->dim_size[2] = handle->desc.S;
layout->dim_size[3] = handle->desc.K;
}
} else {
free(layout);
layout = 0; /* make sure a NULL is returned */
*status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
}
#endif
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_REGULAR_CHANNEL_BIAS
)
||
(
type
==
LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS
)
||
(
type
==
LIBXSMM_DNN_CHANNEL_BIAS
)
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_CHANNEL_SCALAR
;
if
((
handle
->
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
)
>
0
)
{
if
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
{
layout
->
datatype
=
handle
->
datatype_out
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
2
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
2
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
blocksofm
;
}
#if 0
} else if ( (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I16) || (handle->datatype_in == LIBXSMM_DNN_DATATYPE_I8) ) {
layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
layout->num_dims = 3;
layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
layout->dim_size[0] = handle->fm_lp_block;
layout->dim_size[1] = handle->ofmblock;
layout->dim_size[2] = handle->blocksofm;
}
#endif
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
if
((
handle
->
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_NHWC
)
>
0
)
{
layout
->
datatype
=
handle
->
datatype_out
;
if
(
handle
->
datatype_in
==
LIBXSMM_DNN_DATATYPE_F32
)
{
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
1
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
1
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
1
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
*
handle
->
blocksofm
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
(
type
==
LIBXSMM_DNN_BATCH_STATS
)
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_BATCH_STATS
;
if
((
handle
->
buffer_format
&
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
)
>
0
)
{
if
(
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_F32
)
||
(
handle
->
datatype_out
==
LIBXSMM_DNN_DATATYPE_BF16
)
)
{
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
4
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
4
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_type
[
2
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
3
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_X
;
layout
->
dim_size
[
0
]
=
handle
->
ofmblock
;
layout
->
dim_size
[
1
]
=
handle
->
desc
.
N
;
layout
->
dim_size
[
2
]
=
handle
->
blocksofm
;
layout
->
dim_size
[
3
]
=
2
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL
;
}
}
else
if
(
type
==
LIBXSMM_DNN_MAX_STATS_FWD
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_MAX_STATS_FWD
;
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
2
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
2
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
desc
.
N
;
}
}
else
if
(
type
==
LIBXSMM_DNN_MAX_STATS_BWD
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_MAX_STATS_BWD
;
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
2
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
2
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
desc
.
N
;
}
}
else
if
(
type
==
LIBXSMM_DNN_MAX_STATS_UPD
)
{
layout
->
format
=
handle
->
buffer_format
;
layout
->
tensor_type
=
LIBXSMM_DNN_MAX_STATS_UPD
;
layout
->
datatype
=
LIBXSMM_DNN_DATATYPE_F32
;
layout
->
dim_type
=
(
libxsmm_dnn_tensor_dimtype
*
)
malloc
(
2
*
sizeof
(
libxsmm_dnn_tensor_dimtype
));
layout
->
dim_size
=
(
unsigned
int
*
)
malloc
(
2
*
sizeof
(
unsigned
int
));
if
(
0
!=
layout
->
dim_type
&&
0
!=
layout
->
dim_size
)
{
/* TODO: handle the error */
layout
->
num_dims
=
2
;
layout
->
dim_type
[
0
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_C
;
layout
->
dim_type
[
1
]
=
LIBXSMM_DNN_TENSOR_DIMTYPE_N
;
layout
->
dim_size
[
0
]
=
handle
->
ifmblock
;
layout
->
dim_size
[
1
]
=
handle
->
desc
.
N
;
}
}
else
{
free
(
layout
);
layout
=
0
;
/* make sure a NULL is returned */
*
status
=
LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE
;
}
}
else
{
*
status
=
LIBXSMM_DNN_ERR_CREATE_LAYOUT
;
}
}
else
{
*
status
=
LIBXSMM_DNN_ERR_INVALID_HANDLE
;
}
return
layout
;
}
/*
 * Fills the transposed filter tensor bound to the layer (handle->reg_filter_tr)
 * from the regular BF16 filter (handle->reg_filter).
 *
 * Every source element is copied exactly once. In the destination the roles of
 * input and output feature maps are exchanged and both kernel axes are mirrored
 * (index R-1-r and S-1-s). Both tensors are addressed as 7-dimensional arrays
 * whose innermost dimension has handle->fm_lp_block entries.
 *
 * handle: layer with both filter tensors bound; only the tensor data is written.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_INVALID_TENSOR when one of the two
 * filter tensors is not bound, or LIBXSMM_DNN_ERR_INVALID_HANDLE for a NULL handle.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_bf16_filter(const libxsmm_dnn_layer* handle)
{
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  if ((0 == handle->reg_filter) || (0 == handle->reg_filter_tr)) {
    return LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }

  {
    /* TODO handle more datatypes */
    int in_blk, in_fm, row, col, out_blk, out_fm;
    /* number of fm_lp_block-sized chunks per output/input feature-map block */
    int out_chunks = handle->ofmblock / handle->fm_lp_block;
    int in_chunks = handle->ifmblock / handle->fm_lp_block;
    int chunk = handle->fm_lp_block;
    LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, filter, (libxsmm_bfloat16*)handle->reg_filter->data,
      handle->blocksifm, handle->desc.R, handle->desc.S, in_chunks, handle->ofmblock, chunk);
    LIBXSMM_VLA_DECL(7, libxsmm_bfloat16, filter_tr, (libxsmm_bfloat16*)handle->reg_filter_tr->data,
      handle->blocksofm, handle->desc.R, handle->desc.S, out_chunks, handle->ifmblock, chunk);

    /* TODO we might want to do this in parallel.... */
    for (in_blk = 0; in_blk < handle->blocksifm; ++in_blk) {
      for (out_blk = 0; out_blk < handle->blocksofm; ++out_blk) {
        for (row = 0; row < handle->desc.R; ++row) {
          for (col = 0; col < handle->desc.S; ++col) {
            for (out_fm = 0; out_fm < handle->ofmblock; ++out_fm) {
              for (in_fm = 0; in_fm < handle->ifmblock; ++in_fm) {
                /* destination: [ifm block][ofm block][R-1-r][S-1-s][ofm chunk][ifm][ofm within chunk] */
                LIBXSMM_VLA_ACCESS(7, filter_tr, in_blk, out_blk,
                    handle->desc.R - 1 - row, handle->desc.S - 1 - col,
                    out_fm / chunk, in_fm, out_fm % chunk,
                    handle->blocksofm, handle->desc.R, handle->desc.S, out_chunks, handle->ifmblock, chunk) =
                  /* source: [ofm block][ifm block][r][s][ifm chunk][ofm][ifm within chunk] */
                  LIBXSMM_VLA_ACCESS(7, filter, out_blk, in_blk, row, col,
                    in_fm / chunk, out_fm, in_fm % chunk,
                    handle->blocksifm, handle->desc.R, handle->desc.S, in_chunks, handle->ofmblock, chunk);
              }
            }
          }
        }
      }
    }
  }

  return LIBXSMM_DNN_SUCCESS;
}
/*
 * Fills the transposed filter tensor bound to the layer (handle->reg_filter_tr)
 * from the regular single-precision filter (handle->reg_filter).
 *
 * Every source element is copied exactly once. In the destination the roles of
 * input and output feature maps are exchanged and both kernel axes are mirrored
 * (index R-1-r and S-1-s). Both tensors are addressed as 6-dimensional arrays
 * of float.
 *
 * handle: layer with both filter tensors bound; only the tensor data is written.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_INVALID_TENSOR when one of the two
 * filter tensors is not bound, or LIBXSMM_DNN_ERR_INVALID_HANDLE for a NULL handle.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_trans_reg_filter(const libxsmm_dnn_layer* handle)
{
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }
  if ((0 == handle->reg_filter) || (0 == handle->reg_filter_tr)) {
    return LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }

  {
    /* TODO handle more datatypes */
    int in_blk, in_fm, row, col, out_blk, out_fm;
    LIBXSMM_VLA_DECL(6, float, filter, (float*)handle->reg_filter->data,
      handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
    LIBXSMM_VLA_DECL(6, float, filter_tr, (float*)handle->reg_filter_tr->data,
      handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);

    /* TODO we might want to do this in parallel.... */
    for (in_blk = 0; in_blk < handle->blocksifm; ++in_blk) {
      for (out_blk = 0; out_blk < handle->blocksofm; ++out_blk) {
        for (row = 0; row < handle->desc.R; ++row) {
          for (col = 0; col < handle->desc.S; ++col) {
            for (out_fm = 0; out_fm < handle->ofmblock; ++out_fm) {
              for (in_fm = 0; in_fm < handle->ifmblock; ++in_fm) {
                /* destination: [ifm block][ofm block][R-1-r][S-1-s][ofm][ifm] */
                LIBXSMM_VLA_ACCESS(6, filter_tr, in_blk, out_blk,
                    handle->desc.R - 1 - row, handle->desc.S - 1 - col, out_fm, in_fm,
                    handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock) =
                  /* source: [ofm block][ifm block][r][s][ifm][ofm] */
                  LIBXSMM_VLA_ACCESS(6, filter, out_blk, in_blk, row, col, in_fm, out_fm,
                    handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
              }
            }
          }
        }
      }
    }
  }

  return LIBXSMM_DNN_SUCCESS;
}
/*
 * Attaches a user tensor to one of the tensor slots of a convolution layer.
 *
 * The tensor is accepted only if its layout compares equal to the layout the
 * layer generates for the requested slot (libxsmm_dnn_create_tensor_datalayout
 * followed by libxsmm_dnn_compare_tensor_datalayout). The layer stores the
 * pointer only (const is cast away); no data is copied.
 *
 * handle: layer to bind to.
 * tensor: tensor to bind.
 * type:   slot selector; must be one of the types listed below.
 *
 * Returns LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE for an unsupported type (checked
 * first, before the pointers), LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR when handle
 * or tensor is NULL, LIBXSMM_DNN_ERR_MISMATCH_TENSOR when the layouts differ, and
 * otherwise the status left behind by the layout comparison.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_FILTER:
    case LIBXSMM_DNN_GRADIENT_FILTER:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
    case LIBXSMM_DNN_REGULAR_FILTER_TRANS:
    case LIBXSMM_DNN_BATCH_STATS:
    case LIBXSMM_DNN_MAX_STATS_FWD:
    case LIBXSMM_DNN_MAX_STATS_BWD:
    case LIBXSMM_DNN_MAX_STATS_UPD:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }

  if ((0 == handle) || (0 == tensor)) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  {
    /* layout the layer expects for this slot; may be NULL, which the comparison is handed as-is */
    libxsmm_dnn_tensor_datalayout* expected_layout = libxsmm_dnn_create_tensor_datalayout(handle, type, &status);

    if (0 != libxsmm_dnn_compare_tensor_datalayout(expected_layout, tensor->layout, &status)) {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    } else {
      libxsmm_dnn_tensor* const bound = (libxsmm_dnn_tensor*)tensor;
      switch (type) {
        case LIBXSMM_DNN_REGULAR_INPUT:         handle->reg_input = bound; break;
        case LIBXSMM_DNN_GRADIENT_INPUT:        handle->grad_input = bound; break;
        case LIBXSMM_DNN_REGULAR_OUTPUT:        handle->reg_output = bound; break;
        case LIBXSMM_DNN_GRADIENT_OUTPUT:       handle->grad_output = bound; break;
        case LIBXSMM_DNN_REGULAR_FILTER:        handle->reg_filter = bound; break;
        case LIBXSMM_DNN_GRADIENT_FILTER:       handle->grad_filter = bound; break;
        case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:  handle->reg_bias = bound; break;
        case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: handle->grad_bias = bound; break;
        case LIBXSMM_DNN_REGULAR_FILTER_TRANS:  handle->reg_filter_tr = bound; break;
        case LIBXSMM_DNN_BATCH_STATS:           handle->batch_stats = bound; break;
        case LIBXSMM_DNN_MAX_STATS_FWD:         handle->maxstats_fwd = bound; break;
        case LIBXSMM_DNN_MAX_STATS_BWD:         handle->maxstats_bwd = bound; break;
        case LIBXSMM_DNN_MAX_STATS_UPD:         handle->maxstats_upd = bound; break;
        default: /* cannot happen: the type was validated above */ break;
      }
    }

    /* the temporary layout is released on both the match and the mismatch path */
    libxsmm_dnn_destroy_tensor_datalayout(expected_layout);
  }

  return status;
}
/*
 * Returns the tensor currently stored in one of the layer's tensor slots.
 *
 * handle: layer to query.
 * type:   slot selector; must be one of the types listed below.
 * status: written unconditionally, so it must not be NULL. Receives
 *         LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE (checked
 *         first) or LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR (NULL handle).
 *
 * Returns the stored pointer (whatever the slot holds), or NULL on error.
 */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_get_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
{
  libxsmm_dnn_tensor* result = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_FILTER:
    case LIBXSMM_DNN_GRADIENT_FILTER:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
    case LIBXSMM_DNN_REGULAR_FILTER_TRANS:
    case LIBXSMM_DNN_BATCH_STATS:
    case LIBXSMM_DNN_MAX_STATS_FWD:
    case LIBXSMM_DNN_MAX_STATS_BWD:
    case LIBXSMM_DNN_MAX_STATS_UPD:
      break;
    default:
      *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      return result;
  }

  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
    return result;
  }

  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:         result = handle->reg_input; break;
    case LIBXSMM_DNN_GRADIENT_INPUT:        result = handle->grad_input; break;
    case LIBXSMM_DNN_REGULAR_OUTPUT:        result = handle->reg_output; break;
    case LIBXSMM_DNN_GRADIENT_OUTPUT:       result = handle->grad_output; break;
    case LIBXSMM_DNN_REGULAR_FILTER:        result = handle->reg_filter; break;
    case LIBXSMM_DNN_GRADIENT_FILTER:       result = handle->grad_filter; break;
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:  result = handle->reg_bias; break;
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: result = handle->grad_bias; break;
    case LIBXSMM_DNN_REGULAR_FILTER_TRANS:  result = handle->reg_filter_tr; break;
    case LIBXSMM_DNN_BATCH_STATS:           result = handle->batch_stats; break;
    case LIBXSMM_DNN_MAX_STATS_FWD:         result = handle->maxstats_fwd; break;
    case LIBXSMM_DNN_MAX_STATS_BWD:         result = handle->maxstats_bwd; break;
    case LIBXSMM_DNN_MAX_STATS_UPD:         result = handle->maxstats_upd; break;
    default: /* cannot happen: the type was validated above */ break;
  }

  return result;
}
/*
 * Detaches the tensor stored in one of the layer's tensor slots by resetting the
 * slot to NULL. The tensor object itself is not touched or freed here.
 *
 * handle: layer to modify.
 * type:   slot selector; must be one of the types listed below.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE (checked
 * first) or LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR (NULL handle).
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_tensor(libxsmm_dnn_layer* handle, const libxsmm_dnn_tensor_type type)
{
  /* check for tensor type */
  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:
    case LIBXSMM_DNN_GRADIENT_INPUT:
    case LIBXSMM_DNN_REGULAR_OUTPUT:
    case LIBXSMM_DNN_GRADIENT_OUTPUT:
    case LIBXSMM_DNN_REGULAR_FILTER:
    case LIBXSMM_DNN_GRADIENT_FILTER:
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
    case LIBXSMM_DNN_REGULAR_FILTER_TRANS:
    case LIBXSMM_DNN_BATCH_STATS:
    case LIBXSMM_DNN_MAX_STATS_FWD:
    case LIBXSMM_DNN_MAX_STATS_BWD:
    case LIBXSMM_DNN_MAX_STATS_UPD:
      break;
    default:
      return LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
  }

  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  switch (type) {
    case LIBXSMM_DNN_REGULAR_INPUT:         handle->reg_input = 0; break;
    case LIBXSMM_DNN_GRADIENT_INPUT:        handle->grad_input = 0; break;
    case LIBXSMM_DNN_REGULAR_OUTPUT:        handle->reg_output = 0; break;
    case LIBXSMM_DNN_GRADIENT_OUTPUT:       handle->grad_output = 0; break;
    case LIBXSMM_DNN_REGULAR_FILTER:        handle->reg_filter = 0; break;
    case LIBXSMM_DNN_GRADIENT_FILTER:       handle->grad_filter = 0; break;
    case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:  handle->reg_bias = 0; break;
    case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS: handle->grad_bias = 0; break;
    case LIBXSMM_DNN_REGULAR_FILTER_TRANS:  handle->reg_filter_tr = 0; break;
    case LIBXSMM_DNN_BATCH_STATS:           handle->batch_stats = 0; break;
    case LIBXSMM_DNN_MAX_STATS_FWD:         handle->maxstats_fwd = 0; break;
    case LIBXSMM_DNN_MAX_STATS_BWD:         handle->maxstats_bwd = 0; break;
    case LIBXSMM_DNN_MAX_STATS_UPD:         handle->maxstats_upd = 0; break;
    default: /* cannot happen: the type was validated above */ break;
  }

  return LIBXSMM_DNN_SUCCESS;
}
/*
 * Reports how many bytes of scratch memory the caller has to provide.
 *
 * The value is handle->scratch_size plus 64 bytes; the extra 64 bytes leave room
 * for libxsmm_dnn_bind_scratch to round the user pointer up to a 64-byte boundary.
 * The same size is returned for every compute kind.
 *
 * handle: layer to query.
 * kind:   compute kind; only validated, it does not influence the size.
 * status: written unconditionally, so it must not be NULL. Receives
 *         LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_INVALID_KIND (the size is still
 *         returned in that case) or LIBXSMM_DNN_ERR_INVALID_HANDLE.
 *
 * Returns the size in bytes, or 0 for a NULL handle.
 */
LIBXSMM_API size_t libxsmm_dnn_get_scratch_size(const libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
{
  size_t required = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  if (0 == handle) {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
    return required;
  }

  if ((LIBXSMM_DNN_COMPUTE_KIND_FWD != kind) && (LIBXSMM_DNN_COMPUTE_KIND_BWD != kind)
   && (LIBXSMM_DNN_COMPUTE_KIND_UPD != kind) && (LIBXSMM_DNN_COMPUTE_KIND_ALL != kind))
  {
    *status = LIBXSMM_DNN_ERR_INVALID_KIND;
  }

  /* an invalid kind is flagged above but does not suppress the size */
  required += handle->scratch_size + 64;
  return required;
}
/*
 * Binds user-provided scratch memory to the layer.
 *
 * The layer stores the first 64-byte aligned address at or after `scratch`, so
 * the buffer should be sized with libxsmm_dnn_get_scratch_size (which adds 64
 * bytes of slack for this rounding). const is cast away when storing the pointer.
 *
 * handle:  layer to bind to.
 * kind:    compute kind; only validated. Note that the scratch is bound even if
 *          the kind turns out to be invalid (kept for backward compatibility).
 * scratch: start of the scratch buffer.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED (NULL scratch,
 * checked first), LIBXSMM_DNN_ERR_INVALID_HANDLE (NULL handle) or
 * LIBXSMM_DNN_ERR_INVALID_KIND.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_bind_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind, const void* scratch)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 == scratch) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
  } else if (0 == handle) {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  } else {
    const uintptr_t address = (uintptr_t)scratch;
    const uintptr_t misalignment = address % 64;
    /* bytes to skip in order to reach the next 64-byte boundary (0 if already aligned) */
    const size_t offset = (0 == misalignment) ? 0 : (size_t)(64 - misalignment);

    handle->scratch = (void*)(address + offset);

    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD:
      case LIBXSMM_DNN_COMPUTE_KIND_BWD:
      case LIBXSMM_DNN_COMPUTE_KIND_UPD:
      case LIBXSMM_DNN_COMPUTE_KIND_ALL:
        break;
      default:
        status = LIBXSMM_DNN_ERR_INVALID_KIND;
    }
  }

  return status;
}
/*
 * Detaches the scratch memory from the layer by resetting handle->scratch to
 * NULL. The memory itself is not freed here.
 *
 * handle: layer to modify.
 * kind:   compute kind; only validated. The scratch pointer is cleared even if
 *         the kind is invalid.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_INVALID_KIND or
 * LIBXSMM_DNN_ERR_INVALID_HANDLE (NULL handle).
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_release_scratch(libxsmm_dnn_layer* handle, const libxsmm_dnn_compute_kind kind)
{
  if (0 == handle) {
    return LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  handle->scratch = 0;

  if ((LIBXSMM_DNN_COMPUTE_KIND_FWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind)
   || (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_ALL == kind))
  {
    return LIBXSMM_DNN_SUCCESS;
  }
  return LIBXSMM_DNN_ERR_INVALID_KIND;
}
LIBXSMM_API_INLINE
libxsmm_dnn_err_t
internal_execute_st
(
libxsmm_dnn_layer
*
handle
,
libxsmm_dnn_compute_kind
kind
,
int
start_thread
,
int
tid
)
{
libxsmm_dnn_err_t
status
=
LIBXSMM_DNN_SUCCESS
;
if
(
0
!=
handle
)
{
switch
(
handle
->
algo
)
{
case
LIBXSMM_DNN_CONV_ALGO_DIRECT
:
{
switch
(
kind
)
{
case
LIBXSMM_DNN_COMPUTE_KIND_FWD
:
{
switch
(
handle
->
buffer_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_fwd_custom_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_NHWC
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_RSCK
:
{
status
=
libxsmm_dnn_convolve_st_fwd_nhwc_rsck
(
handle
,
start_thread
,
tid
);
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_fwd_nhwc_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_COMPUTE_KIND_BWD
:
{
switch
(
handle
->
buffer_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_bwd_custom_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_NHWC
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_RSCK
:
{
status
=
libxsmm_dnn_convolve_st_bwd_nhwc_rsck
(
handle
,
start_thread
,
tid
);
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_bwd_nhwc_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_COMPUTE_KIND_UPD
:
{
switch
(
handle
->
buffer_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_upd_custom_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_NHWC
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_RSCK
:
{
status
=
libxsmm_dnn_convolve_st_upd_nhwc_rsck
(
handle
,
start_thread
,
tid
);
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_upd_nhwc_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_COMPUTE_KIND_BWDUPD
:
{
switch
(
handle
->
buffer_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_upd_custom_custom
(
handle
,
start_thread
,
tid
);
status
=
libxsmm_dnn_convolve_st_bwd_custom_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_NHWC
:
{
switch
(
handle
->
filter_format
)
{
case
LIBXSMM_DNN_TENSOR_FORMAT_RSCK
:
{
status
=
libxsmm_dnn_convolve_st_upd_nhwc_rsck
(
handle
,
start_thread
,
tid
);
status
=
libxsmm_dnn_convolve_st_bwd_nhwc_rsck
(
handle
,
start_thread
,
tid
);
}
break
;
case
LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM
:
{
status
=
libxsmm_dnn_convolve_st_upd_nhwc_custom
(
handle
,
start_thread
,
tid
);
status
=
libxsmm_dnn_convolve_st_bwd_nhwc_custom
(
handle
,
start_thread
,
tid
);
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE
;
}
}
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_KIND
;
}
}
}
break
;
default:
{
status
=
LIBXSMM_DNN_ERR_INVALID_ALGO
;
}
}
}
else
{
status
=
LIBXSMM_DNN_ERR_INVALID_HANDLE
;
}
return
status
;
}
/*
 * Public entry point for executing one compute pass from a caller-managed
 * thread; it forwards all arguments unchanged to internal_execute_st and
 * returns that function's status.
 *
 * handle:       layer to execute.
 * kind:         compute kind to run.
 * start_thread: forwarded as-is.
 * tid:          forwarded as-is.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_execute_st(libxsmm_dnn_layer* handle,
  libxsmm_dnn_compute_kind kind, /*unsigned*/int start_thread, /*unsigned*/int tid)
{
  const libxsmm_dnn_err_t result = internal_execute_st(handle, kind, start_thread, tid);
  return result;
}
/*
 * Executes one compute pass of the layer.
 *
 * With OpenMP enabled, a parallel region with handle->desc.threads threads is
 * opened and every thread calls internal_execute_st with start_thread 0 and its
 * OpenMP thread number as tid. Without OpenMP a single call with tid 0 is made.
 * The status returned by internal_execute_st is not propagated (void function).
 *
 * handle: layer to execute; a NULL handle is a no-op in both build variants
 *         (the OpenMP variant used to dereference it in the num_threads clause).
 * kind:   compute kind to run.
 */
LIBXSMM_API void libxsmm_dnn_execute(libxsmm_dnn_layer* handle, libxsmm_dnn_compute_kind kind)
{
#if defined(_OPENMP)
  /* the num_threads clause reads handle->desc, hence the guard must come first */
  if (0 != handle) {
# pragma omp parallel num_threads(handle->desc.threads)
    {
      const int tid = omp_get_thread_num();
      internal_execute_st(handle, kind, 0, tid);
    }
  }
#else
  internal_execute_st(handle, kind, 0/*start_thread*/, 0/*tid*/);
#endif
}
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_backward.h"
#include "libxsmm_main.h"
/*
 * Prototypes of the datatype-specific backward-convolution kernels of this file.
 * All of them share one signature: the layer handle plus the start_thread and
 * tid values passed down by the caller.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
/*
 * Transposes one tile of BF16 data using AVX-512 (no-op without
 * LIBXSMM_INTRINSICS_AVX512_CORE).
 *
 * Eight 512-bit rows are loaded from source (row i at source + i*source_stride)
 * and eight 512-bit rows are stored to dest (row i at dest + i*dest_stride).
 * Both strides are counted in libxsmm_bfloat16 elements. The aligned load/store
 * intrinsics are used, i.e. every row address must be 64-byte aligned.
 *
 * source_void:   first row of the input tile.
 * dest_void:     first row of the output tile.
 * source_stride: distance between input rows, in elements.
 * dest_stride:   distance between output rows, in elements.
 */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_transpose_16x16_kernel(void* source_void, void* dest_void, int source_stride, int dest_stride)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  libxsmm_bfloat16 *const in = (libxsmm_bfloat16*)source_void;
  libxsmm_bfloat16 *const out = (libxsmm_bfloat16*)dest_void;
  /* per 128-bit lane: byte order 0,1,4,5,2,3,6,7 | 8,9,12,13,10,11,14,15,
   * i.e. the two middle 16-bit values of every 64-bit group are swapped */
  const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100);

  /* stage 1: load the rows and regroup the 16-bit values inside each 64-bit group */
  const __m512i row0 = _mm512_shuffle_epi8(_mm512_load_epi32(in), abcdefgh_to_abefcdgh);
  const __m512i row1 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride), abcdefgh_to_abefcdgh);
  const __m512i row2 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride * 2), abcdefgh_to_abefcdgh);
  const __m512i row3 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride * 3), abcdefgh_to_abefcdgh);
  const __m512i row4 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride * 4), abcdefgh_to_abefcdgh);
  const __m512i row5 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride * 5), abcdefgh_to_abefcdgh);
  const __m512i row6 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride * 6), abcdefgh_to_abefcdgh);
  const __m512i row7 = _mm512_shuffle_epi8(_mm512_load_epi32(in + source_stride * 7), abcdefgh_to_abefcdgh);

  /* stage 2: interleave the 64-bit groups of neighbouring rows */
  const __m512i lo01 = _mm512_unpacklo_epi64(row0, row1);
  const __m512i hi01 = _mm512_unpackhi_epi64(row0, row1);
  const __m512i lo23 = _mm512_unpacklo_epi64(row2, row3);
  const __m512i hi23 = _mm512_unpackhi_epi64(row2, row3);
  const __m512i lo45 = _mm512_unpacklo_epi64(row4, row5);
  const __m512i hi45 = _mm512_unpackhi_epi64(row4, row5);
  const __m512i lo67 = _mm512_unpacklo_epi64(row6, row7);
  const __m512i hi67 = _mm512_unpackhi_epi64(row6, row7);

  /* stage 3: gather even (0x88) and odd (0xdd) 128-bit lanes of rows 0-3 and 4-7 */
  const __m512i lo0123_even = _mm512_shuffle_i32x4(lo01, lo23, 0x88);
  const __m512i lo0123_odd  = _mm512_shuffle_i32x4(lo01, lo23, 0xdd);
  const __m512i hi0123_even = _mm512_shuffle_i32x4(hi01, hi23, 0x88);
  const __m512i hi0123_odd  = _mm512_shuffle_i32x4(hi01, hi23, 0xdd);
  const __m512i lo4567_even = _mm512_shuffle_i32x4(lo45, lo67, 0x88);
  const __m512i lo4567_odd  = _mm512_shuffle_i32x4(lo45, lo67, 0xdd);
  const __m512i hi4567_even = _mm512_shuffle_i32x4(hi45, hi67, 0x88);
  const __m512i hi4567_odd  = _mm512_shuffle_i32x4(hi45, hi67, 0xdd);

  /* stage 4: same even/odd lane selection across the two row groups -> output rows */
  const __m512i out0 = _mm512_shuffle_i32x4(lo0123_even, lo4567_even, 0x88);
  const __m512i out1 = _mm512_shuffle_i32x4(hi0123_even, hi4567_even, 0x88);
  const __m512i out2 = _mm512_shuffle_i32x4(lo0123_odd, lo4567_odd, 0x88);
  const __m512i out3 = _mm512_shuffle_i32x4(hi0123_odd, hi4567_odd, 0x88);
  const __m512i out4 = _mm512_shuffle_i32x4(lo0123_even, lo4567_even, 0xdd);
  const __m512i out5 = _mm512_shuffle_i32x4(hi0123_even, hi4567_even, 0xdd);
  const __m512i out6 = _mm512_shuffle_i32x4(lo0123_odd, lo4567_odd, 0xdd);
  const __m512i out7 = _mm512_shuffle_i32x4(hi0123_odd, hi4567_odd, 0xdd);

  _mm512_store_epi32(out, out0);
  _mm512_store_epi32(out + dest_stride, out1);
  _mm512_store_epi32(out + dest_stride * 2, out2);
  _mm512_store_epi32(out + dest_stride * 3, out3);
  _mm512_store_epi32(out + dest_stride * 4, out4);
  _mm512_store_epi32(out + dest_stride * 5, out5);
  _mm512_store_epi32(out + dest_stride * 6, out6);
  _mm512_store_epi32(out + dest_stride * 7, out7);
#else
  LIBXSMM_UNUSED(source_void);
  LIBXSMM_UNUSED(dest_void);
  LIBXSMM_UNUSED(source_stride);
  LIBXSMM_UNUSED(dest_stride);
#endif
}
/*
 * Transposes a BF16 matrix tile by tile with bf16_vnni_transpose_16x16_kernel
 * (no-op without LIBXSMM_INTRINSICS_AVX512_CORE).
 *
 * The matrix is walked as (N/16) x (M/16) tiles. Tile (i, j) is read at
 * src + i*16*ld_in + j*32 and written at dst + j*16*ld_out + i*32; the tile
 * kernel is given row strides of 2*ld_in and 2*ld_out. Because of the integer
 * divisions, any remainder of M or N modulo 16 is not processed.
 *
 * src/dst:      input and output matrix.
 * M, N:         extents used to derive the tile counts.
 * ld_in/ld_out: leading dimensions of src and dst, in elements.
 */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void bf16_vnni_transpose_kernel(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  const int tiles_m = M / 16;
  const int tiles_n = N / 16;
  int tile_n, tile_m;

  for (tile_n = 0; tile_n < tiles_n; tile_n++) {
    for (tile_m = 0; tile_m < tiles_m; tile_m++) {
      bf16_vnni_transpose_16x16_kernel(
        src + tile_n * 16 * ld_in + tile_m * 32,
        dst + tile_m * 16 * ld_out + tile_n * 32,
        ld_in * 2, ld_out * 2);
    }
  }
#else
  LIBXSMM_UNUSED(src);
  LIBXSMM_UNUSED(dst);
  LIBXSMM_UNUSED(M);
  LIBXSMM_UNUSED(N);
  LIBXSMM_UNUSED(ld_in);
  LIBXSMM_UNUSED(ld_out);
#endif
}
/*
 * Backward convolution for F32 data in the custom buffer/filter format.
 *
 * Depending on handle->use_fallback_bwd_loops, either the batch-reduce GEMM
 * variant or the fallback variant is set up; the loop nests themselves come
 * from the included template files, which consume the typedefs and locals
 * declared here (hence their names must stay as they are).
 *
 * The environment variable BRGEMM_PF_OOB (parsed with atoi) switches the
 * batch-reduce kernels to LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB when it is > 0.
 *
 * Returns LIBXSMM_DNN_SUCCESS unless changed by the template code, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without LIBXSMM_INTRINSICS_AVX512.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (handle->use_fallback_bwd_loops == 0) {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
    const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    /* with spread_input_bwd the leading dimension of C is scaled by desc.v */
    const libxsmm_blasint ldC = (handle->spread_input_bwd == 1)
      ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)handle->ifmblock;
    /* beta = 0 overwrites C, beta = 1 accumulates into it */
    const float beta = (handle->avoid_acc_load_bwd ? 0.f : 1.f);
    int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");

    if (0 != env_brgemm_pf_oob) {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (brgemm_pf_oob > 0) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }

    { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
      /* second kernel: same shape except that N uses bwd_ofw_rb - 1 */
      gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c"
    }
  }
  else {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v * handle->ifmblock);

    { /* let's do a ifmblock x ofw_rb x ofmblock GEMM :-) or in other words M=nbIfm, N=ofw, K=nbOfm (col-major) */
      gemm_function gemm_kernel = libxsmm_smmdispatch(
        handle->ifmblock, handle->ofw, handle->ofmblock,
        NULL, NULL, &ldC, NULL, NULL, NULL, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c"
    }
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 backward convolution for the custom/custom format, built for the
 * LIBXSMM_X86_AVX512_CORE target. Unlike the non-"_emu" variant, this one
 * does not define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI before pulling in the
 * BF16 helper macros.
 *
 * handle        layer handle; supplies the blocking factors, strides and
 *               flags read below
 * start_thread  not referenced directly here; presumably consumed by the
 * tid           included template code (confirm against the templates)
 *
 * Returns `status`: initialised to LIBXSMM_DNN_SUCCESS, or set to
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without
 * LIBXSMM_INTRINSICS_AVX512_CORE.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs below are part of their contract.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  if (0 == handle->use_fallback_bwd_loops) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bfloat16 element_filter_type;
      typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
      typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
      const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
      const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
      /* C is strided by desc.v when the input is "spread" */
      const libxsmm_blasint ldC = (1 == handle->spread_input_bwd)
        ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
        : (libxsmm_blasint)handle->ifmblock;
      /* beta = 0 overwrites C, beta = 1 accumulates into C */
      const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
      int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
      /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
      gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      /* same kernel with N computed from (bwd_ofw_rb - 1) */
      gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      /* the two shapes again, dispatched through the bmm (bf16/bf16) function type */
      gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
  }
  else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v * handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    /* strides are passed in bytes (element count times sizeof(libxsmm_bfloat16)) */
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* strided batch-reduce GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(
      handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b,
      NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 backward convolution for the custom/custom format that uses the
 * kernels already stored in the handle (bwd_compute_kernel_offs,
 * bwd_compute_kernel_strd, bwd_config_kernel); built for the
 * LIBXSMM_X86_AVX512_CORE target without LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI.
 *
 * handle        layer handle; supplies kernels, blocking factors and flags
 * start_thread  not referenced directly here; presumably consumed by the
 * tid           included template code (confirm against the templates)
 *
 * Returns `status`: initialised to LIBXSMM_DNN_SUCCESS, or set to
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without
 * LIBXSMM_INTRINSICS_AVX512_CORE.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs below are part of their contract.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  if (0 == handle->use_fallback_bwd_loops) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bfloat16 element_filter_type;
      typedef libxsmm_bsmmfunction gemm_function;
      typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs;
      typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd;
      /* kernels are taken from the handle rather than dispatched here */
      gemm_function tile_config_kernel = handle->bwd_config_kernel;
      gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs;
      gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd;
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
  }
  else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v * handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    /* strides are passed in bytes (element count times sizeof(libxsmm_bfloat16)) */
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* strided batch-reduce GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(
      handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b,
      NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 backward convolution for the custom/custom format.
 *
 * With LIBXSMM_INTRINSICS_AVX512_CPX the function is built for the
 * LIBXSMM_X86_AVX512_CPX target and defines LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
 * around the BF16 helper macros and the template body. Without it, the
 * function simply forwards to the "_emu" variant.
 *
 * handle        layer handle; supplies the blocking factors, strides and flags
 * start_thread  not referenced directly in the CPX build; presumably consumed
 * tid           by the included template code (confirm against the templates)
 *
 * Returns `status` (initialised to LIBXSMM_DNN_SUCCESS), or the result of the
 * "_emu" variant in the non-CPX build.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs below are part of their contract.
 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  if (0 == handle->use_fallback_bwd_loops) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
      typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
      const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
      const libxsmm_blasint ldB = (libxsmm_blasint)handle->ofmblock;
      /* C is strided by desc.v when the input is "spread" */
      const libxsmm_blasint ldC = (1 == handle->spread_input_bwd)
        ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
        : (libxsmm_blasint)handle->ifmblock;
      /* beta = 0 overwrites C, beta = 1 accumulates into C */
      const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
      int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
      /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
      gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      /* same kernel with N computed from (bwd_ofw_rb - 1) */
      gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      /* the two shapes again, dispatched through the bmm (bf16/bf16) function type */
      gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
      gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
  else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v * handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    /* strides are passed in bytes (element count times sizeof(libxsmm_bfloat16)) */
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* strided batch-reduce GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(
      handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b,
      NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no CPX intrinsics available at build time: reuse the "_emu" implementation */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
/* BF16 backward convolution for the custom/custom format that uses the
 * kernels already stored in the handle (bwd_compute_kernel_offs,
 * bwd_compute_kernel_strd, bwd_config_kernel).
 *
 * With LIBXSMM_INTRINSICS_AVX512_CPX the function is built for the
 * LIBXSMM_X86_AVX512_CPX target and defines LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
 * around the BF16 helper macros and the template body. Without it, the
 * function simply forwards to the "_emu_amx" variant.
 *
 * handle        layer handle; supplies kernels, blocking factors and flags
 * start_thread  not referenced directly in the CPX build; presumably consumed
 * tid           by the included template code (confirm against the templates)
 *
 * Returns `status` (initialised to LIBXSMM_DNN_SUCCESS), or the result of the
 * "_emu_amx" variant in the non-CPX build.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs below are part of their contract.
 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  if (0 == handle->use_fallback_bwd_loops) {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
    {
      typedef libxsmm_bsmmfunction gemm_function;
      typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs;
      typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function_strd;
      /* kernels are taken from the handle rather than dispatched here */
      gemm_function tile_config_kernel = handle->bwd_config_kernel;
      gemm_br_function_offs br_gemm_kernel_offs = handle->bwd_compute_kernel_offs;
      gemm_br_function_strd br_gemm_kernel_strd = handle->bwd_compute_kernel_strd;
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
    }
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
  else {
    typedef libxsmm_bfloat16 element_input_type;
    typedef libxsmm_bfloat16 element_output_type;
    typedef libxsmm_bfloat16 element_filter_type;
    typedef libxsmm_bsmmfunction_reducebatch_strd brgemm_function;
    const libxsmm_blasint ldC = (libxsmm_blasint)(handle->desc.v * handle->ifmblock);
    int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    /* strides are passed in bytes (element count times sizeof(libxsmm_bfloat16)) */
    int stride_a = handle->ifmblock * handle->desc.R * handle->desc.S * handle->ofmblock * sizeof(libxsmm_bfloat16);
    int stride_b = handle->ofmblock * handle->ofwp * handle->ofhp * sizeof(libxsmm_bfloat16);
    /* strided batch-reduce GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
    brgemm_function bf16fp32_brgemm_kernel = libxsmm_bsmmdispatch_reducebatch_strd(
      handle->ifmblock, handle->ofw, handle->ofmblock, stride_a, stride_b,
      NULL, NULL, &ldC, NULL, NULL, &l_flags, NULL);
# define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
    /* portable helper macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
# undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no CPX intrinsics available at build time: reuse the "_emu_amx" implementation */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
}
#endif
/* FP32 backward convolution for NHWC activations with a custom-format filter,
 * built for the LIBXSMM_X86_AVX512 target. The shared "nhwc_custom-rsck"
 * templates are specialised via LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM.
 *
 * handle        layer handle; supplies the blocking factors, padding and flags
 * start_thread  not referenced directly here; presumably consumed by the
 * tid           included template code (confirm against the templates)
 *
 * Environment: BRGEMM_PF_OOB, parsed with atoi(); a value > 0 selects
 * LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB instead of LIBXSMM_GEMM_PREFETCH_NONE.
 *
 * Returns `status`: initialised to LIBXSMM_DNN_SUCCESS, or set to
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without
 * LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs below are part of their contract.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (0 == handle->use_fallback_bwd_loops) {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    /* C is additionally strided by desc.v when the input is "spread" */
    const libxsmm_blasint ldC = (1 == handle->spread_input_bwd)
      ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
    /* beta = 0 overwrites C, beta = 1 accumulates into C */
    const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
    int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
    /* an unset variable leaves brgemm_pf_oob at 0 */
    if (0 != env_brgemm_pf_oob) {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (0 < brgemm_pf_oob) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }
    { /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
      gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
      /* same kernel with N computed from (bwd_ofw_rb - 1) */
      gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
    }
  }
  else {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    /* a narrower ldC is used when the logical and physical input padding differ */
    const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
      ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
    /* plain GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
    gemm_function gemm_kernel = libxsmm_smmdispatch(
      handle->ifmblock, handle->ofw, handle->ofmblock,
      &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* FP32 backward convolution for NHWC activations with an RSCK filter, built
 * for the LIBXSMM_X86_AVX512 target. The shared "nhwc_custom-rsck" templates
 * are specialised via LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK.
 *
 * handle        layer handle; supplies the blocking factors, padding and flags
 * start_thread  not referenced directly here; presumably consumed by the
 * tid           included template code (confirm against the templates)
 *
 * Environment: BRGEMM_PF_OOB, parsed with atoi(); a value > 0 selects
 * LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB instead of LIBXSMM_GEMM_PREFETCH_NONE.
 *
 * Returns `status`: initialised to LIBXSMM_DNN_SUCCESS, or set to
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without
 * LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs below are part of their contract.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if (0 == handle->use_fallback_bwd_loops) {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    /* C is additionally strided by desc.v when the input is "spread" */
    const libxsmm_blasint ldC = (1 == handle->spread_input_bwd)
      ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
    /* beta = 0 overwrites C, beta = 1 accumulates into C */
    const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
    int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
    int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    int brgemm_pf_oob = 0;
    const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
    /* an unset variable leaves brgemm_pf_oob at 0 */
    if (0 != env_brgemm_pf_oob) {
      brgemm_pf_oob = atoi(env_brgemm_pf_oob);
    }
    if (0 < brgemm_pf_oob) {
      prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
    }
    { /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
      gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
      /* same kernel with N computed from (bwd_ofw_rb - 1) */
      gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(
        handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
        &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
    }
  }
  else {
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
    const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
    /* a narrower ldC is used when the logical and physical input padding differ */
    const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
      ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
      : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
    /* plain GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
    gemm_function gemm_kernel = libxsmm_smmdispatch(
      handle->ifmblock, handle->ofw, handle->ofmblock,
      &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Entry point of the backward convolution for the custom/custom format:
 * validates the handle's bindings, then selects a kernel by target
 * architecture and data type.
 *
 * handle        layer handle; grad_input, grad_output, reg_filter and
 *               scratch must all be non-zero
 * start_thread  forwarded to the selected kernel (the generic path presumably
 * tid           uses them inside the included template)
 *
 * Returns
 *   LIBXSMM_DNN_ERR_DATA_NOT_BOUND        if one of the four bindings is 0
 *   LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE  if no kernel matches the data types
 *   otherwise the status of the selected kernel / generic path
 *
 * Selection on AVX-512 builds when LIBXSMM_X86_AVX512 <= target_archid <=
 * LIBXSMM_X86_ALLFEAT:
 *   F32                              -> ..._f32_f32
 *   BF16, CPX build:  [CORE, CPX)    -> ..._bf16_bf16_emu
 *                     [CPX, SPR)     -> ..._bf16_bf16
 *                     >= SPR         -> ..._bf16_bf16_amx
 *   BF16, CORE build: [CORE, SPR)    -> ..._bf16_bf16_emu
 *                     >= SPR         -> ..._bf16_bf16_emu_amx
 * Every other case runs the generic F32 path implemented inline below.
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs of the generic path are part of their
 * contract.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* gradient tensors, filter and scratch have to be bound */
  if (0 == handle->grad_input || 0 == handle->grad_output || 0 == handle->reg_filter || 0 == handle->scratch) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }

  /* AVX-512 code paths */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((LIBXSMM_X86_AVX512 <= handle->target_archid) && (LIBXSMM_X86_ALLFEAT >= handle->target_archid)) {
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_out) {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_f32_f32(handle, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    else if (LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_out
          && LIBXSMM_X86_AVX512_CORE <= handle->target_archid && LIBXSMM_X86_AVX512_CPX > handle->target_archid)
    {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_out
          && LIBXSMM_X86_AVX512_CPX <= handle->target_archid && LIBXSMM_X86_AVX512_SPR > handle->target_archid)
    {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16(handle, start_thread, tid);
    }
    else if (LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_out
          && LIBXSMM_X86_AVX512_SPR <= handle->target_archid)
    {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_amx(handle, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if (LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_out
          && LIBXSMM_X86_AVX512_CORE <= handle->target_archid && LIBXSMM_X86_AVX512_SPR > handle->target_archid)
    {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
    }
    else if (LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_out
          && LIBXSMM_X86_AVX512_SPR <= handle->target_archid)
    {
      status = libxsmm_dnn_convolve_st_bwd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
    }
#endif
    else {
      return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    }
  }
  else
#endif
  { /* generic path: F32 only */
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->datatype_out) {
      if (0 == handle->use_fallback_bwd_loops) {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
        const libxsmm_blasint ldx = ((libxsmm_blasint)handle->ofmblock);
        const libxsmm_blasint ldA = handle->ifmblock;
        /* C is strided by desc.v when the input is "spread" */
        const libxsmm_blasint ldC = (1 == handle->spread_input_bwd) ? handle->ifmblock * handle->desc.v : handle->ifmblock;
        /* beta = 0 overwrites C, beta = 1 accumulates into C */
        const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
        int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
        int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
        int brgemm_pf_oob = 0;
        const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
        /* an unset variable leaves brgemm_pf_oob at 0 */
        if (0 != env_brgemm_pf_oob) {
          brgemm_pf_oob = atoi(env_brgemm_pf_oob);
        }
        if (0 < brgemm_pf_oob) {
          prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
        }
        { /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
          gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(
            handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
            &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
          /* same kernel with N computed from (bwd_ofw_rb - 1) */
          gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(
            handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
            &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_generic.tpl.c"
        }
      }
      else {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction gemm_function;
        const libxsmm_blasint ldx = ((libxsmm_blasint)handle->desc.v * handle->ifmblock);
        /* plain GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock; only ldc is overridden */
        gemm_function gemm_kernel = libxsmm_smmdispatch(
          handle->ifmblock, handle->ofw, handle->ofmblock,
          NULL, NULL, &ldx, NULL, NULL, NULL, NULL);
# include "template/libxsmm_dnn_convolve_st_bwd_custom_custom_fallback_generic.tpl.c"
      }
    }
    else {
      return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    }
  }

  return status;
}
/* Entry point of the backward convolution for NHWC activations with an RSCK
 * filter: validates the handle's bindings, then either calls the AVX-512 F32
 * kernel or runs the generic F32 path implemented inline below.
 *
 * handle        layer handle; grad_input, grad_output, reg_filter and
 *               scratch must all be non-zero
 * start_thread  forwarded to the AVX-512 kernel (the generic path presumably
 * tid           uses them inside the included template)
 *
 * Returns
 *   LIBXSMM_DNN_ERR_DATA_NOT_BOUND        if one of the four bindings is 0
 *   LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE  if input/output are not both F32
 *   otherwise the status of the selected kernel / generic path
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs of the generic path are part of their
 * contract.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* gradient tensors, filter and scratch have to be bound */
  if (0 == handle->grad_input || 0 == handle->grad_output || 0 == handle->reg_filter || 0 == handle->scratch) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }

  /* AVX-512 code path */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((LIBXSMM_X86_AVX512 <= handle->target_archid) && (LIBXSMM_X86_ALLFEAT >= handle->target_archid)) {
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_out) {
      status = libxsmm_dnn_convolve_st_bwd_nhwc_rsck_f32_f32(handle, start_thread, tid);
    }
    else {
      return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    }
  }
  else
#endif
  { /* generic path: F32 only */
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->datatype_out) {
      if (0 == handle->use_fallback_bwd_loops) {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        /* C is additionally strided by desc.v when the input is "spread" */
        const libxsmm_blasint ldC = (1 == handle->spread_input_bwd)
          ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
        /* beta = 0 overwrites C, beta = 1 accumulates into C */
        const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
        int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
        int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
        int brgemm_pf_oob = 0;
        const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
        /* an unset variable leaves brgemm_pf_oob at 0 */
        if (0 != env_brgemm_pf_oob) {
          brgemm_pf_oob = atoi(env_brgemm_pf_oob);
        }
        if (0 < brgemm_pf_oob) {
          prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
        }
        { /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
          gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(
            handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
          /* same kernel with N computed from (bwd_ofw_rb - 1) */
          gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(
            handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
        }
      }
      else {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction gemm_function;
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        /* a narrower ldC is used when the logical and physical input padding differ */
        const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
          ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
        /* plain GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
        gemm_function gemm_kernel = libxsmm_smmdispatch(
          handle->ifmblock, handle->ofw, handle->ofmblock,
          &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_RSCK
      }
    }
    else {
      return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    }
  }

  return status;
}
/* Entry point of the backward convolution for NHWC activations with a
 * custom-format filter: validates the handle's bindings, then either calls
 * the AVX-512 F32 kernel or runs the generic F32 path implemented inline
 * below.
 *
 * handle        layer handle; grad_input, grad_output, reg_filter and
 *               scratch must all be non-zero
 * start_thread  forwarded to the AVX-512 kernel (the generic path presumably
 * tid           uses them inside the included template)
 *
 * Returns
 *   LIBXSMM_DNN_ERR_DATA_NOT_BOUND        if one of the four bindings is 0
 *   LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE  if input/output are not both F32
 *   otherwise the status of the selected kernel / generic path
 *
 * NOTE: the included templates are expanded textually inside this function,
 * so the local names and typedefs of the generic path are part of their
 * contract.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* gradient tensors, filter and scratch have to be bound */
  if (0 == handle->grad_input || 0 == handle->grad_output || 0 == handle->reg_filter || 0 == handle->scratch) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }

  /* AVX-512 code path */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ((LIBXSMM_X86_AVX512 <= handle->target_archid) && (LIBXSMM_X86_ALLFEAT >= handle->target_archid)) {
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_out) {
      status = libxsmm_dnn_convolve_st_bwd_nhwc_custom_f32_f32(handle, start_thread, tid);
    }
    else {
      return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    }
  }
  else
#endif
  { /* generic path: F32 only */
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->datatype_out) {
      if (0 == handle->use_fallback_bwd_loops) {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        /* C is additionally strided by desc.v when the input is "spread" */
        const libxsmm_blasint ldC = (1 == handle->spread_input_bwd)
          ? (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock);
        /* beta = 0 overwrites C, beta = 1 accumulates into C */
        const float beta = (0 != handle->avoid_acc_load_bwd) ? 0.f : 1.f;
        int l_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
        int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
        int brgemm_pf_oob = 0;
        const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
        /* an unset variable leaves brgemm_pf_oob at 0 */
        if (0 != env_brgemm_pf_oob) {
          brgemm_pf_oob = atoi(env_brgemm_pf_oob);
        }
        if (0 < brgemm_pf_oob) {
          prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
        }
        { /* batch-reduce GEMM (col-major): M = ifmblock, N = bwd_ofh_rb * bwd_ofw_rb, K = ofmblock */
          gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(
            handle->ifmblock, handle->bwd_ofh_rb * handle->bwd_ofw_rb, handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
          /* same kernel with N computed from (bwd_ofw_rb - 1) */
          gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(
            handle->ifmblock, handle->bwd_ofh_rb * (handle->bwd_ofw_rb - 1), handle->ofmblock,
            &ldA, &ldB, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
        }
      }
      else {
        typedef float element_input_type;
        typedef float element_output_type;
        typedef float element_filter_type;
        typedef libxsmm_smmfunction gemm_function;
        const libxsmm_blasint ldA = (libxsmm_blasint)handle->ifmblock;
        const libxsmm_blasint ldB = (libxsmm_blasint)(handle->blocksofm * handle->ofmblock);
        /* a narrower ldC is used when the logical and physical input padding differ */
        const libxsmm_blasint ldC = ((handle->desc.pad_h != handle->desc.pad_h_in) || (handle->desc.pad_w != handle->desc.pad_w_in))
          ? (libxsmm_blasint)(handle->ifmblock * handle->desc.v)
          : (libxsmm_blasint)(handle->blocksifm * handle->ifmblock * handle->desc.v);
        /* plain GEMM (col-major): M = ifmblock, N = ofw, K = ofmblock */
        gemm_function gemm_kernel = libxsmm_smmdispatch(
          handle->ifmblock, handle->ofw, handle->ofmblock,
          &ldA, &ldB, &ldC, NULL, NULL, NULL, NULL);
# define LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_bwd_nhwc_custom-rsck_fallback_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_BWD_DIRECT_GENERIC_NHWC_CUSTOM
      }
    }
    else {
      return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
    }
  }

  return status;
}
third_party/libxsmm/src/libxsmm_dnn_convolution_backward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_BACKWARD_H
#define LIBXSMM_DNN_CONVOLUTION_BACKWARD_H
#include <libxsmm_dnn_convolution.h>
/*
 * Internal entry points of the backward convolution pass.
 * All three share one signature: the layer handle, the id of the first
 * thread of the team (start_thread) and the id of the calling thread (tid).
 * Each returns a libxsmm_dnn_err_t status code.
 * NOTE(review): the name suffix presumably encodes the activation/filter
 * layout pair (custom/custom, NHWC/RSCK, NHWC/custom) - confirm.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_bwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif
/* LIBXSMM_DNN_CONVOLUTION_BACKWARD_H */
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke, Evangelos Georganas, Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_forward.h"
#include "libxsmm_main.h"
/*
 * Prototypes of the datatype-specific forward kernels implemented further
 * down in this file. They all take the layer handle, the id of the first
 * thread of the team and the id of the calling thread, and return a
 * libxsmm_dnn_err_t status code.
 */
/* F32 input / F32 output */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
/* BF16 input / BF16 output (emulated and native, with and without AMX) */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
/* I8 input with I32 or I8 output */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid);
/*
 * Forward convolution kernel for F32 input/output (custom_custom variant),
 * compiled with AVX-512 intrinsics enabled.
 *
 * Dispatches two address-based reduce-batch SMM kernels (M=ofmblock,
 * K=ifmblock, N=fwd_ofh_rb*fwd_ofw_rb and N=fwd_ofh_rb*(fwd_ofw_rb-1)) and
 * then expands the generic forward template.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included template - do not rename without checking it.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
#if 1
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
  /* leading dimension of the input operand: dense when the input is packed, strided by desc.v otherwise */
  const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
  const libxsmm_blasint ldA = handle->ofmblock;
  const libxsmm_blasint ldC = handle->ofmblock;
  /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int brgemm_pf_oob = 0;
  /* optional runtime switch: BRGEMM_PF_OOB > 0 selects the BRGEMM_OOB prefetch mode */
  const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
  if ( NULL != env_brgemm_pf_oob ) {
    brgemm_pf_oob = atoi(env_brgemm_pf_oob);
  }
  if ( brgemm_pf_oob > 0 ) {
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
  }
  { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
#else
  /* disabled alternative: take the kernels precomputed in the handle */
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
  typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs;
  typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd;
  {
    gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32;
    gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32;
    gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32;
    gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32;
#endif
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c"
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution kernel for BF16 input/output (custom_custom variant)
 * built for AVX512_CORE, i.e. without LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
 * being defined around the BF16 helper macros.
 *
 * Dispatches four address-based reduce-batch kernels: two via
 * libxsmm_bsmmdispatch_reducebatch_addr and two via
 * libxsmm_bmmdispatch_reducebatch_addr, each for N=fwd_ofh_rb*fwd_ofw_rb and
 * N=fwd_ofh_rb*(fwd_ofw_rb-1), then expands the generic BF16 template.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512_CORE.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included templates - do not rename without checking them.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  {
    typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
    typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
    /* leading dimension of the input operand: dense when the input is packed, strided by desc.v otherwise */
    const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
    const libxsmm_blasint ldA = handle->ofmblock;
    const libxsmm_blasint ldC = handle->ofmblock;
    /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
    const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
    int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
    /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
    gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution kernel for BF16 input/output (custom_custom, AMX
 * template variant) built for AVX512_CORE, i.e. without
 * LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI being defined around the BF16 macros.
 *
 * No kernel is dispatched here: the offset-based and strided reduce-batch
 * kernels as well as the tile-configuration kernel are read from the handle
 * and handed to the AMX template through the locals declared below.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512_CORE.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included templates - do not rename without checking them.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction gemm_function;
  typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a;
  typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b;
  typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd;
  /* kernels stored in the handle */
  gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a;
  gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b;
  gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd;
  gemm_function tile_config_kernel = handle->fwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution kernel for BF16 input/output (custom_custom variant).
 *
 * With LIBXSMM_INTRINSICS_AVX512_CPX available the function is built for
 * AVX512_CPX and defines LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI around the BF16
 * helper macros and the generic BF16 template; otherwise it simply forwards
 * to libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns a libxsmm_dnn_err_t status code (LIBXSMM_DNN_SUCCESS initially).
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included templates - do not rename without checking them.
 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* this inner guard repeats the enclosing one, hence its #else branch is never compiled */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
  typedef libxsmm_bmmfunction_reducebatch_addr gemm_br_function_bf16bf16;
  /* leading dimension of the input operand: dense when the input is packed, strided by desc.v otherwise */
  const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
  const libxsmm_blasint ldA = handle->ofmblock;
  const libxsmm_blasint ldC = handle->ofmblock;
  /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  int l_flags = ( LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') ) | handle->fwd_flags;
  /* M=ofmblock, K=ifmblock; N=fwd_ofh_rb*fwd_ofw_rb for the first kernel of each pair, N=fwd_ofh_rb*(fwd_ofw_rb-1) for the second */
  gemm_br_function br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
  gemm_br_function br_gemm_kernel2 = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
  gemm_br_function_bf16bf16 br_gemm_kernel_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
  gemm_br_function_bf16bf16 br_gemm_kernel2_bf16bf16 = libxsmm_bmmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, NULL);
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no AVX512_CPX intrinsics at compile time: fall back to the AVX512_CORE build of the kernel */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid );
}
#endif
/*
 * Forward convolution kernel for BF16 input/output (custom_custom, AMX
 * template variant).
 *
 * With LIBXSMM_INTRINSICS_AVX512_CPX available the function is built for
 * AVX512_CPX and defines LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI around the BF16
 * helper macros and the AMX template; otherwise it simply forwards to
 * libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns a libxsmm_dnn_err_t status code (LIBXSMM_DNN_SUCCESS initially).
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included templates - do not rename without checking them.
 */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* this inner guard repeats the enclosing one, hence its #else branch is never compiled */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction gemm_function;
  typedef libxsmm_bmmfunction_reducebatch_offs gemm_br_function_offs_a;
  typedef libxsmm_bsmmfunction_reducebatch_offs gemm_br_function_offs_b;
  typedef libxsmm_bmmfunction_reducebatch_strd gemm_br_function_strd;
  /* kernels stored in the handle */
  gemm_br_function_offs_a br_gemm_kernel_offs_a = handle->fwd_compute_kernel_offs_a;
  gemm_br_function_offs_b br_gemm_kernel_offs_b = handle->fwd_compute_kernel_offs_b;
  gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd;
  gemm_function tile_config_kernel = handle->fwd_config_kernel;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* no AVX512_CPX intrinsics at compile time: fall back to the AVX512_CORE build of the kernel */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid );
}
#endif
/*
 * Forward convolution kernel for 8-bit input with 32-bit integer output
 * (custom_custom variant), compiled with AVX-512 intrinsics enabled.
 *
 * Element types: unsigned char input, char filter, int output. The strided
 * and offset-based reduce-batch kernels are read from handle->gemm_fwd and
 * handle->gemm_fwd2 and handed to the i8i32 template via the locals below.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included template - do not rename without checking it.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef unsigned char element_input_type;
  typedef int element_output_type;
  typedef char element_filter_type;
  /* Basically we need only offset based and strided BRGEMMs */
  libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.subimrs;
  libxsmm_subimmfunction_reducebatch_strd br_gemm_kernel_strided2 = handle->gemm_fwd2.xgemm.subimrs;
  libxsmm_subimmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.subimro;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i32.tpl.c"
#else
  /* built without AVX-512 intrinsics: nothing to run */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution kernel for 8-bit input and 8-bit output
 * (custom_custom variant), compiled with AVX-512 intrinsics enabled.
 *
 * Element types: unsigned char input and output, char filter. The strided
 * and offset-based reduce-batch kernels are read from handle->gemm_fwd and
 * handed to the i8i8 template via the locals below.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included template - do not rename without checking it.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef unsigned char element_input_type;
  typedef unsigned char element_output_type;
  typedef char element_filter_type;
  /* Basically we need only offset based and strided BRGEMMs */
  libxsmm_sububmmfunction_reducebatch_strd br_gemm_kernel_strided = handle->gemm_fwd.xgemm.sububmrs;
  libxsmm_sububmmfunction_reducebatch_offs br_gemm_kernel_offset = handle->gemm_fwd.xgemm.sububmro;
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic_i8i8.tpl.c"
#else
  /* built without AVX-512 intrinsics: nothing to run */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution kernel for F32 input/output (nhwc_custom variant),
 * compiled with AVX-512 intrinsics enabled.
 *
 * Dispatches two address-based reduce-batch SMM kernels (M=ofmblock,
 * K=ifmblock, N=fwd_ofh_rb*fwd_ofw_rb and N=fwd_ofh_rb*(fwd_ofw_rb-1)) and
 * expands the shared nhwc template with
 * LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM defined.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included template - do not rename without checking it.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* input leading dimension spans all input feature-map blocks; additionally strided by desc.v when the input is not packed */
  const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
  const libxsmm_blasint ldA = handle->ofmblock;
  const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
  /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
  int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int brgemm_pf_oob = 0;
  /* optional runtime switch: BRGEMM_PF_OOB > 0 selects the BRGEMM_OOB prefetch mode */
  const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
  if ( NULL != env_brgemm_pf_oob ) {
    brgemm_pf_oob = atoi(env_brgemm_pf_oob);
  }
  if ( brgemm_pf_oob > 0 ) {
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
  }
  { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution kernel for F32 input/output (nhwc_rsck variant),
 * compiled with AVX-512 intrinsics enabled.
 *
 * Same structure as the nhwc_custom F32 kernel, except that ldA spans all
 * output feature-map blocks (blocksofm*ofmblock) and the shared nhwc template
 * is expanded with LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK defined.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns LIBXSMM_DNN_SUCCESS as initial status (the template may presumably
 * update it), or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH if the file was built
 * without LIBXSMM_INTRINSICS_AVX512.
 *
 * NOTE(review): the typedefs and locals below are presumably referenced by
 * name from the included template - do not rename without checking it.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* input leading dimension spans all input feature-map blocks; additionally strided by desc.v when the input is not packed */
  const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
  const libxsmm_blasint ldA = handle->blocksofm*handle->ofmblock;
  const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
  /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
  const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
  int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
  int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  int brgemm_pf_oob = 0;
  /* optional runtime switch: BRGEMM_PF_OOB > 0 selects the BRGEMM_OOB prefetch mode */
  const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
  if ( NULL != env_brgemm_pf_oob ) {
    brgemm_pf_oob = atoi(env_brgemm_pf_oob);
  }
  if ( brgemm_pf_oob > 0 ) {
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
  }
  { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
    gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
    gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/*
 * Forward convolution entry point for the custom_custom variant.
 *
 * 1. Fails with LIBXSMM_DNN_ERR_DATA_NOT_BOUND unless input, output and
 *    filter are all registered in the handle.
 * 2. If built with AVX-512 intrinsics and handle->target_archid lies in
 *    [LIBXSMM_X86_AVX512, LIBXSMM_X86_ALLFEAT], forwards to the kernel that
 *    matches the input/output datatype pair (F32/F32, I8/I32, I8/I8, and -
 *    depending on the available intrinsics and target_archid - one of the
 *    BF16/BF16 kernels).
 * 3. Otherwise runs the generic F32 path inline.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns the status of the selected kernel, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE if no kernel matches the datatypes.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_f32_f32( handle, start_thread, tid );
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I32 ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i32( handle, start_thread, tid );
    }
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_I8 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_I8 ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_i8_i8( handle, start_thread, tid );
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    /* BF16: [AVX512_CORE, AVX512_CPX) -> emulation, [AVX512_CPX, AVX512_SPR) -> native, >= AVX512_SPR -> AMX */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid );
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16( handle, start_thread, tid );
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_amx( handle, start_thread, tid );
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    /* BF16 without CPX intrinsics: [AVX512_CORE, AVX512_SPR) -> emulation, >= AVX512_SPR -> emulated AMX variant */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu( handle, start_thread, tid );
    }
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && handle->target_archid >= LIBXSMM_X86_AVX512_SPR ) {
      status = libxsmm_dnn_convolve_st_fwd_custom_custom_bf16_bf16_emu_amx( handle, start_thread, tid );
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* NOTE(review): this path tests handle->datatype_in/out whereas the AVX-512 branch
     * tests handle->desc.datatype_in/out - confirm both are kept in sync by the caller */
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* NOTE(review): typedefs and locals are presumably referenced by name from the included template */
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
#if 1
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
      /* leading dimension of the input operand: dense when the input is packed, strided by desc.v otherwise */
      const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->ifmblock : (libxsmm_blasint)handle->desc.v*handle->ifmblock;
      const libxsmm_blasint ldA = handle->ofmblock;
      const libxsmm_blasint ldC = handle->ofmblock;
      /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
      const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
      int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
      int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
      int brgemm_pf_oob = 0;
      /* optional runtime switch: BRGEMM_PF_OOB > 0 selects the BRGEMM_OOB prefetch mode */
      const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
      if ( NULL != env_brgemm_pf_oob ) {
        brgemm_pf_oob = atoi(env_brgemm_pf_oob);
      }
      if ( brgemm_pf_oob > 0 ) {
        prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
      }
      { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
        gemm_br_function_addr br_gemm_kernel_a_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
        gemm_br_function_addr br_gemm_kernel_b_addr = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
#else
      /* disabled alternative: take the kernels precomputed in the handle */
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function_addr;
      typedef libxsmm_smmfunction_reducebatch_offs gemm_br_function_offs;
      typedef libxsmm_smmfunction_reducebatch_strd gemm_br_function_strd;
      {
        gemm_br_function_addr br_gemm_kernel_a_addr = handle->fwd_compute_kernel_addr_a_f32;
        gemm_br_function_addr br_gemm_kernel_b_addr = handle->fwd_compute_kernel_addr_b_f32;
        gemm_br_function_offs br_gemm_kernel_offs = handle->fwd_compute_kernel_offs_f32;
        gemm_br_function_strd br_gemm_kernel_strd = handle->fwd_compute_kernel_strd_f32;
#endif
# include "template/libxsmm_dnn_convolve_st_fwd_custom_custom_generic.tpl.c"
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
/*
 * Forward convolution entry point for the nhwc_custom variant.
 *
 * 1. Fails with LIBXSMM_DNN_ERR_DATA_NOT_BOUND unless input, output and
 *    filter are all registered in the handle.
 * 2. If built with AVX-512 intrinsics and handle->target_archid lies in
 *    [LIBXSMM_X86_AVX512, LIBXSMM_X86_ALLFEAT], forwards F32/F32 to
 *    libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32.
 * 3. Otherwise runs the generic F32 path inline.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns the status of the selected kernel, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE for any datatype pair but F32/F32.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_convolve_st_fwd_nhwc_custom_f32_f32( handle, start_thread, tid );
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* NOTE(review): this path tests handle->datatype_in/out whereas the AVX-512 branch
     * tests handle->desc.datatype_in/out - confirm both are kept in sync by the caller */
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* NOTE(review): typedefs and locals are presumably referenced by name from the included template */
      /* input leading dimension spans all input feature-map blocks; additionally strided by desc.v when the input is not packed */
      const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
      const libxsmm_blasint ldA = handle->ofmblock;
      const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
      /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
      const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
      int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
      int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
      int brgemm_pf_oob = 0;
      /* optional runtime switch: BRGEMM_PF_OOB > 0 selects the BRGEMM_OOB prefetch mode */
      const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
      if ( NULL != env_brgemm_pf_oob ) {
        brgemm_pf_oob = atoi(env_brgemm_pf_oob);
      }
      if ( brgemm_pf_oob > 0 ) {
        prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
      }
      { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
        gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
        gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
/*
 * Forward convolution entry point for the nhwc_rsck variant.
 *
 * 1. Fails with LIBXSMM_DNN_ERR_DATA_NOT_BOUND unless input, output and
 *    filter are all registered in the handle.
 * 2. If built with AVX-512 intrinsics and handle->target_archid lies in
 *    [LIBXSMM_X86_AVX512, LIBXSMM_X86_ALLFEAT], forwards F32/F32 to
 *    libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32.
 * 3. Otherwise runs the generic F32 path inline.
 *
 * handle:       convolution layer handle.
 * start_thread: id of the first thread of the team.
 * tid:          id of the calling thread.
 * Returns the status of the selected kernel, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE for any datatype pair but F32/F32.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if we have input, output and filter */
  if (handle->reg_input == 0 || handle->reg_output == 0 || handle->reg_filter == 0) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on AVX512 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( (handle->target_archid >= LIBXSMM_X86_AVX512) && (handle->target_archid <= LIBXSMM_X86_ALLFEAT) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_convolve_st_fwd_nhwc_rsck_f32_f32( handle, start_thread, tid );
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }
  else
#endif
  {
    /* NOTE(review): this path tests handle->datatype_in/out whereas the AVX-512 branch
     * tests handle->desc.datatype_in/out - confirm both are kept in sync by the caller */
    if (handle->datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      /* NOTE(review): typedefs and locals are presumably referenced by name from the included template */
      /* input leading dimension spans all input feature-map blocks; additionally strided by desc.v when the input is not packed */
      const libxsmm_blasint ldx = (handle->pack_input == 1) ? (libxsmm_blasint)handle->blocksifm*handle->ifmblock : (libxsmm_blasint)handle->blocksifm*handle->desc.v*handle->ifmblock;
      const libxsmm_blasint ldA = handle->blocksofm*handle->ofmblock;
      const libxsmm_blasint ldC = handle->blocksofm*handle->ofmblock;
      /* beta=0 overwrites the accumulator, beta=1 accumulates into it */
      const float beta = (handle->avoid_acc_load) ? 0.f : 1.f;
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
      int l_flags = ( LIBXSMM_GEMM_FLAGS('N', 'N') ) | handle->fwd_flags;
      int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
      int brgemm_pf_oob = 0;
      /* optional runtime switch: BRGEMM_PF_OOB > 0 selects the BRGEMM_OOB prefetch mode */
      const char *const env_brgemm_pf_oob = getenv("BRGEMM_PF_OOB");
      if ( NULL != env_brgemm_pf_oob ) {
        brgemm_pf_oob = atoi(env_brgemm_pf_oob);
      }
      if ( brgemm_pf_oob > 0 ) {
        prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_BRGEMM_OOB);
      }
      { /* let's do a ofmblock x ofw_rb x ifmblock GEMM :-) or in other words M=nbOfm, N=ofw, K=nbIfm (col-major) */
        gemm_br_function br_gemm_kernel = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*handle->fwd_ofw_rb, handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
        gemm_br_function br_gemm_kernel2 = libxsmm_smmdispatch_reducebatch_addr(handle->ofmblock, handle->fwd_ofh_rb*(handle->fwd_ofw_rb-1), handle->ifmblock, &ldA, &ldx, &ldC, NULL, &beta, &l_flags, &prefetch_mode);
# define LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_fwd_nhwc_custom-rsck_generic.tpl.c"
# undef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
      }
    }
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
third_party/libxsmm/src/libxsmm_dnn_convolution_forward.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_CONVOLUTION_FORWARD_H
#define LIBXSMM_DNN_CONVOLUTION_FORWARD_H
#include <libxsmm_dnn_convolution.h>
/*
 * Internal entry points of the forward convolution pass. All three share the
 * same signature: the layer handle plus two integers (start_thread, tid), and
 * each returns a libxsmm_dnn_err_t status code.
 * NOTE(review): the suffixes (custom_custom, nhwc_custom, nhwc_rsck) presumably
 * name the activation/filter tensor formats handled by each variant - confirm
 * against the implementations.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);

LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_fwd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif
/* LIBXSMM_DNN_CONVOLUTION_FORWARD_H */
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.c
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke, Ankush Mandal, Jason Sewall (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_convolution_weight_update.h"
#include "libxsmm_main.h"
/* function prototypes for below implementations */
/*
 * Forward declarations of the weight-update variants. Every variant takes the
 * layer handle plus two integers (start_thread, tid) and returns a
 * libxsmm_dnn_err_t status code.
 * NOTE(review): the name suffixes presumably encode tensor format, data type
 * (f32/bf16) and code path (emu/amx) - confirm against the definitions.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid);
/*
 * Transposes one tile of 16-bit elements with AVX-512 shuffles.
 *
 * Reads 16 rows of 512 bits each (row k starts at in + k*ld_in) and writes
 * 32 rows of 256 bits each (row k starts at out + k*ld_out). With 16-bit
 * elements this is a 16x32 -> 32x16 transpose.
 * NOTE(review): assumes libxsmm_bfloat16 is 16 bits wide - confirm.
 *
 * in      first element of the source tile
 * out     first element of the destination tile
 * ld_in   distance (in elements) between consecutive source rows
 * ld_out  distance (in elements) between consecutive destination rows
 *
 * Without LIBXSMM_INTRINSICS_AVX512_CORE the function is a no-op.
 */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_32x16(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
/* function-local shorthands; each expands to plain straight-line intrinsics */
# define TRANSPOSE_32X16_LOAD(ROW) _mm512_loadu_si512(in + (ROW) * ld_in)
# define TRANSPOSE_32X16_UNPACK(BITS, LO, HI, A, B) \
    LO = _mm512_unpacklo_epi##BITS(A, B); \
    HI = _mm512_unpackhi_epi##BITS(A, B)
# define TRANSPOSE_32X16_SHUFFLE(EVEN, ODD, A, B) \
    EVEN = _mm512_shuffle_i32x4(A, B, 0x88); \
    ODD = _mm512_shuffle_i32x4(A, B, 0xdd)
# define TRANSPOSE_32X16_MERGE(LO, HI, A, B) \
    LO = _mm512_permutex2var_epi64(A, idx_lo, B); \
    HI = _mm512_permutex2var_epi64(B, idx_hi, A)
# define TRANSPOSE_32X16_STORE2(ROW, SRC) \
    LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + ((ROW) + 0) * ld_out, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(SRC, 0)); \
    LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + ((ROW) + 1) * ld_out, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(SRC, 1))

  __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
  __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
  /* 64-bit lane selectors for the final two-source permutes */
  const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);

  /* fetch the 16 source rows */
  r0 = TRANSPOSE_32X16_LOAD(0);
  r1 = TRANSPOSE_32X16_LOAD(1);
  r2 = TRANSPOSE_32X16_LOAD(2);
  r3 = TRANSPOSE_32X16_LOAD(3);
  r4 = TRANSPOSE_32X16_LOAD(4);
  r5 = TRANSPOSE_32X16_LOAD(5);
  r6 = TRANSPOSE_32X16_LOAD(6);
  r7 = TRANSPOSE_32X16_LOAD(7);
  r8 = TRANSPOSE_32X16_LOAD(8);
  r9 = TRANSPOSE_32X16_LOAD(9);
  ra = TRANSPOSE_32X16_LOAD(10);
  rb = TRANSPOSE_32X16_LOAD(11);
  rc = TRANSPOSE_32X16_LOAD(12);
  rd = TRANSPOSE_32X16_LOAD(13);
  re = TRANSPOSE_32X16_LOAD(14);
  rf = TRANSPOSE_32X16_LOAD(15);

  /* interleave 16-bit elements of neighbouring rows */
  TRANSPOSE_32X16_UNPACK(16, t0, t1, r0, r1);
  TRANSPOSE_32X16_UNPACK(16, t2, t3, r2, r3);
  TRANSPOSE_32X16_UNPACK(16, t4, t5, r4, r5);
  TRANSPOSE_32X16_UNPACK(16, t6, t7, r6, r7);
  TRANSPOSE_32X16_UNPACK(16, t8, t9, r8, r9);
  TRANSPOSE_32X16_UNPACK(16, ta, tb, ra, rb);
  TRANSPOSE_32X16_UNPACK(16, tc, td, rc, rd);
  TRANSPOSE_32X16_UNPACK(16, te, tf, re, rf);

  /* interleave 32-bit groups */
  TRANSPOSE_32X16_UNPACK(32, r0, r1, t0, t2);
  TRANSPOSE_32X16_UNPACK(32, r2, r3, t1, t3);
  TRANSPOSE_32X16_UNPACK(32, r4, r5, t4, t6);
  TRANSPOSE_32X16_UNPACK(32, r6, r7, t5, t7);
  TRANSPOSE_32X16_UNPACK(32, r8, r9, t8, ta);
  TRANSPOSE_32X16_UNPACK(32, ra, rb, t9, tb);
  TRANSPOSE_32X16_UNPACK(32, rc, rd, tc, te);
  TRANSPOSE_32X16_UNPACK(32, re, rf, td, tf);

  /* interleave 64-bit groups */
  TRANSPOSE_32X16_UNPACK(64, t0, t1, r0, r4);
  TRANSPOSE_32X16_UNPACK(64, t2, t3, r1, r5);
  TRANSPOSE_32X16_UNPACK(64, t4, t5, r2, r6);
  TRANSPOSE_32X16_UNPACK(64, t6, t7, r3, r7);
  TRANSPOSE_32X16_UNPACK(64, t8, t9, r8, rc);
  TRANSPOSE_32X16_UNPACK(64, ta, tb, r9, rd);
  TRANSPOSE_32X16_UNPACK(64, tc, td, ra, re);
  TRANSPOSE_32X16_UNPACK(64, te, tf, rb, rf);

  /* regroup 128-bit lanes (0x88 and 0xdd selections of each pair) */
  TRANSPOSE_32X16_SHUFFLE(r0, r4, t0, t1);
  TRANSPOSE_32X16_SHUFFLE(r1, r5, t2, t3);
  TRANSPOSE_32X16_SHUFFLE(r2, r6, t4, t5);
  TRANSPOSE_32X16_SHUFFLE(r3, r7, t6, t7);
  TRANSPOSE_32X16_SHUFFLE(r8, rc, t8, t9);
  TRANSPOSE_32X16_SHUFFLE(r9, rd, ta, tb);
  TRANSPOSE_32X16_SHUFFLE(ra, re, tc, td);
  TRANSPOSE_32X16_SHUFFLE(rb, rf, te, tf);

  /* combine the two register halves (r0..r7 with r8..rf) */
  TRANSPOSE_32X16_MERGE(t0, t8, r0, r8);
  TRANSPOSE_32X16_MERGE(t1, t9, r1, r9);
  TRANSPOSE_32X16_MERGE(t2, ta, r2, ra);
  TRANSPOSE_32X16_MERGE(t3, tb, r3, rb);
  TRANSPOSE_32X16_MERGE(t4, tc, r4, rc);
  TRANSPOSE_32X16_MERGE(t5, td, r5, rd);
  TRANSPOSE_32X16_MERGE(t6, te, r6, re);
  TRANSPOSE_32X16_MERGE(t7, tf, r7, rf);

  /* every 512-bit register carries two consecutive 256-bit output rows */
  TRANSPOSE_32X16_STORE2(0, t0);
  TRANSPOSE_32X16_STORE2(2, t1);
  TRANSPOSE_32X16_STORE2(4, t2);
  TRANSPOSE_32X16_STORE2(6, t3);
  TRANSPOSE_32X16_STORE2(8, t4);
  TRANSPOSE_32X16_STORE2(10, t5);
  TRANSPOSE_32X16_STORE2(12, t6);
  TRANSPOSE_32X16_STORE2(14, t7);
  TRANSPOSE_32X16_STORE2(16, t8);
  TRANSPOSE_32X16_STORE2(18, t9);
  TRANSPOSE_32X16_STORE2(20, ta);
  TRANSPOSE_32X16_STORE2(22, tb);
  TRANSPOSE_32X16_STORE2(24, tc);
  TRANSPOSE_32X16_STORE2(26, td);
  TRANSPOSE_32X16_STORE2(28, te);
  TRANSPOSE_32X16_STORE2(30, tf);

# undef TRANSPOSE_32X16_STORE2
# undef TRANSPOSE_32X16_MERGE
# undef TRANSPOSE_32X16_SHUFFLE
# undef TRANSPOSE_32X16_UNPACK
# undef TRANSPOSE_32X16_LOAD
#else
  LIBXSMM_UNUSED(in);
  LIBXSMM_UNUSED(out);
  LIBXSMM_UNUSED(ld_in);
  LIBXSMM_UNUSED(ld_out);
#endif
}
/*
 * Remainder variant of the 16-bit tile transpose: only the first `col` source
 * rows are read and only the first `col` elements of each destination row are
 * written.
 *
 * Source row k (k < col) is loaded as 512 bits from in + k*ld_in; all other
 * row registers are left undefined. Each of the 32 destination rows
 * (out + k*ld_out) receives a masked 256-bit store whose mask has the lowest
 * `col` bits set, so the undefined lanes never reach memory.
 * NOTE(review): assumes libxsmm_bfloat16 is 16 bits wide - confirm.
 *
 * in      first element of the source tile
 * out     first element of the destination tile
 * col     number of valid source rows; rows are loaded only for 1..15, any
 *         other value loads nothing
 * ld_in   distance (in elements) between consecutive source rows
 * ld_out  distance (in elements) between consecutive destination rows
 *
 * Without LIBXSMM_INTRINSICS_AVX512_CORE the function is a no-op.
 */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_32xcols(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
/* function-local shorthands; each expands to plain straight-line intrinsics */
# define TRANSPOSE_32XCOLS_LOAD(ROW) \
    (((ROW) < nrows) ? _mm512_loadu_si512(in + (ROW) * ld_in) : LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32())
# define TRANSPOSE_32XCOLS_UNPACK(BITS, LO, HI, A, B) \
    LO = _mm512_unpacklo_epi##BITS(A, B); \
    HI = _mm512_unpackhi_epi##BITS(A, B)
# define TRANSPOSE_32XCOLS_SHUFFLE(EVEN, ODD, A, B) \
    EVEN = _mm512_shuffle_i32x4(A, B, 0x88); \
    ODD = _mm512_shuffle_i32x4(A, B, 0xdd)
# define TRANSPOSE_32XCOLS_MERGE(LO, HI, A, B) \
    LO = _mm512_permutex2var_epi64(A, idx_lo, B); \
    HI = _mm512_permutex2var_epi64(B, idx_hi, A)
# define TRANSPOSE_32XCOLS_STORE2(ROW, SRC) \
    _mm256_mask_storeu_epi16(out + ((ROW) + 0) * ld_out, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(SRC, 0)); \
    _mm256_mask_storeu_epi16(out + ((ROW) + 1) * ld_out, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(SRC, 1))

  __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
  __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
  /* rows are fetched only for col in 1..15; anything else fetches nothing */
  const int nrows = (0 < col && col < 16) ? col : 0;
  /* 64-bit lane selectors for the final two-source permutes */
  const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
  /* lowest `col` bits set: one bit per 16-bit element that may be written */
  __mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1);

  /* fetch the valid rows; the remaining registers stay undefined */
  r0 = TRANSPOSE_32XCOLS_LOAD(0);
  r1 = TRANSPOSE_32XCOLS_LOAD(1);
  r2 = TRANSPOSE_32XCOLS_LOAD(2);
  r3 = TRANSPOSE_32XCOLS_LOAD(3);
  r4 = TRANSPOSE_32XCOLS_LOAD(4);
  r5 = TRANSPOSE_32XCOLS_LOAD(5);
  r6 = TRANSPOSE_32XCOLS_LOAD(6);
  r7 = TRANSPOSE_32XCOLS_LOAD(7);
  r8 = TRANSPOSE_32XCOLS_LOAD(8);
  r9 = TRANSPOSE_32XCOLS_LOAD(9);
  ra = TRANSPOSE_32XCOLS_LOAD(10);
  rb = TRANSPOSE_32XCOLS_LOAD(11);
  rc = TRANSPOSE_32XCOLS_LOAD(12);
  rd = TRANSPOSE_32XCOLS_LOAD(13);
  re = TRANSPOSE_32XCOLS_LOAD(14);
  /* the 16th row is never part of a remainder tile */
  rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();

  /* interleave 16-bit elements of neighbouring rows */
  TRANSPOSE_32XCOLS_UNPACK(16, t0, t1, r0, r1);
  TRANSPOSE_32XCOLS_UNPACK(16, t2, t3, r2, r3);
  TRANSPOSE_32XCOLS_UNPACK(16, t4, t5, r4, r5);
  TRANSPOSE_32XCOLS_UNPACK(16, t6, t7, r6, r7);
  TRANSPOSE_32XCOLS_UNPACK(16, t8, t9, r8, r9);
  TRANSPOSE_32XCOLS_UNPACK(16, ta, tb, ra, rb);
  TRANSPOSE_32XCOLS_UNPACK(16, tc, td, rc, rd);
  TRANSPOSE_32XCOLS_UNPACK(16, te, tf, re, rf);

  /* interleave 32-bit groups */
  TRANSPOSE_32XCOLS_UNPACK(32, r0, r1, t0, t2);
  TRANSPOSE_32XCOLS_UNPACK(32, r2, r3, t1, t3);
  TRANSPOSE_32XCOLS_UNPACK(32, r4, r5, t4, t6);
  TRANSPOSE_32XCOLS_UNPACK(32, r6, r7, t5, t7);
  TRANSPOSE_32XCOLS_UNPACK(32, r8, r9, t8, ta);
  TRANSPOSE_32XCOLS_UNPACK(32, ra, rb, t9, tb);
  TRANSPOSE_32XCOLS_UNPACK(32, rc, rd, tc, te);
  TRANSPOSE_32XCOLS_UNPACK(32, re, rf, td, tf);

  /* interleave 64-bit groups */
  TRANSPOSE_32XCOLS_UNPACK(64, t0, t1, r0, r4);
  TRANSPOSE_32XCOLS_UNPACK(64, t2, t3, r1, r5);
  TRANSPOSE_32XCOLS_UNPACK(64, t4, t5, r2, r6);
  TRANSPOSE_32XCOLS_UNPACK(64, t6, t7, r3, r7);
  TRANSPOSE_32XCOLS_UNPACK(64, t8, t9, r8, rc);
  TRANSPOSE_32XCOLS_UNPACK(64, ta, tb, r9, rd);
  TRANSPOSE_32XCOLS_UNPACK(64, tc, td, ra, re);
  TRANSPOSE_32XCOLS_UNPACK(64, te, tf, rb, rf);

  /* regroup 128-bit lanes (0x88 and 0xdd selections of each pair) */
  TRANSPOSE_32XCOLS_SHUFFLE(r0, r4, t0, t1);
  TRANSPOSE_32XCOLS_SHUFFLE(r1, r5, t2, t3);
  TRANSPOSE_32XCOLS_SHUFFLE(r2, r6, t4, t5);
  TRANSPOSE_32XCOLS_SHUFFLE(r3, r7, t6, t7);
  TRANSPOSE_32XCOLS_SHUFFLE(r8, rc, t8, t9);
  TRANSPOSE_32XCOLS_SHUFFLE(r9, rd, ta, tb);
  TRANSPOSE_32XCOLS_SHUFFLE(ra, re, tc, td);
  TRANSPOSE_32XCOLS_SHUFFLE(rb, rf, te, tf);

  /* combine the two register halves (r0..r7 with r8..rf) */
  TRANSPOSE_32XCOLS_MERGE(t0, t8, r0, r8);
  TRANSPOSE_32XCOLS_MERGE(t1, t9, r1, r9);
  TRANSPOSE_32XCOLS_MERGE(t2, ta, r2, ra);
  TRANSPOSE_32XCOLS_MERGE(t3, tb, r3, rb);
  TRANSPOSE_32XCOLS_MERGE(t4, tc, r4, rc);
  TRANSPOSE_32XCOLS_MERGE(t5, td, r5, rd);
  TRANSPOSE_32XCOLS_MERGE(t6, te, r6, re);
  TRANSPOSE_32XCOLS_MERGE(t7, tf, r7, rf);

  /* every 512-bit register carries two consecutive output rows (masked store) */
  TRANSPOSE_32XCOLS_STORE2(0, t0);
  TRANSPOSE_32XCOLS_STORE2(2, t1);
  TRANSPOSE_32XCOLS_STORE2(4, t2);
  TRANSPOSE_32XCOLS_STORE2(6, t3);
  TRANSPOSE_32XCOLS_STORE2(8, t4);
  TRANSPOSE_32XCOLS_STORE2(10, t5);
  TRANSPOSE_32XCOLS_STORE2(12, t6);
  TRANSPOSE_32XCOLS_STORE2(14, t7);
  TRANSPOSE_32XCOLS_STORE2(16, t8);
  TRANSPOSE_32XCOLS_STORE2(18, t9);
  TRANSPOSE_32XCOLS_STORE2(20, ta);
  TRANSPOSE_32XCOLS_STORE2(22, tb);
  TRANSPOSE_32XCOLS_STORE2(24, tc);
  TRANSPOSE_32XCOLS_STORE2(26, td);
  TRANSPOSE_32XCOLS_STORE2(28, te);
  TRANSPOSE_32XCOLS_STORE2(30, tf);

# undef TRANSPOSE_32XCOLS_STORE2
# undef TRANSPOSE_32XCOLS_MERGE
# undef TRANSPOSE_32XCOLS_SHUFFLE
# undef TRANSPOSE_32XCOLS_UNPACK
# undef TRANSPOSE_32XCOLS_LOAD
#else
  LIBXSMM_UNUSED(in);
  LIBXSMM_UNUSED(out);
  LIBXSMM_UNUSED(col);
  LIBXSMM_UNUSED(ld_in);
  LIBXSMM_UNUSED(ld_out);
#endif
}
/*
 * Tiled driver around transpose_32x16 / transpose_32xcols.
 *
 * The index i advances in steps of 32 up to M and the index j in steps of 16
 * up to the largest multiple of 16 not exceeding N; each (i, j) pair hands the
 * tile at in + i + ld_in*j to transpose_32x16, which writes to
 * out + j + i*ld_out. If N is not a multiple of 16, a second sweep over i
 * passes the trailing N % 16 source rows to transpose_32xcols.
 * NOTE(review): i is stepped by 32 without a tail case, so M is presumably a
 * multiple of 32 - confirm with the callers.
 *
 * in      source matrix
 * out     destination matrix
 * M, N    extents as described above
 * ld_in   leading dimension (in elements) of the source
 * ld_out  leading dimension (in elements) of the destination
 *
 * Without LIBXSMM_INTRINSICS_AVX512_CORE the function is a no-op.
 */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
void transpose_input_pixels_bf16(const libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  const int tail = N % 16;      /* trailing source rows that do not fill a tile */
  const int covered = N - tail; /* part of N handled by full tiles */
  int m, n;

  /* full tiles; the inner loop is empty when no complete 16-chunk exists */
  for (m = 0; m < M; m += 32) {
    for (n = 0; n < covered; n += 16) {
      transpose_32x16(in + m + ld_in * n, out + n + m * ld_out, ld_in, ld_out);
    }
  }

  /* remainder strip, processed after all full tiles */
  if (0 != tail) {
    for (m = 0; m < M; m += 32) {
      transpose_32xcols(in + m + ld_in * covered, out + covered + m * ld_out, tail, ld_in, ld_out);
    }
  }
#else
  LIBXSMM_UNUSED(in);
  LIBXSMM_UNUSED(out);
  LIBXSMM_UNUSED(M);
  LIBXSMM_UNUSED(N);
  LIBXSMM_UNUSED(ld_in);
  LIBXSMM_UNUSED(ld_out);
#endif
}
/* FP32 instantiation of the "upd" convolution template for the
 * custom/custom layout, compiled for the LIBXSMM_X86_AVX512 target.
 *
 * handle        convolution layer handle, consumed by the included template
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status`: LIBXSMM_DNN_SUCCESS unless changed, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when the translation unit was built
 * without LIBXSMM_INTRINSICS_AVX512.
 * NOTE(review): the template body is not visible here; it presumably may
 * overwrite `status` - confirm in the .tpl.c file. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512)
  /*__AVX512F__*/
  /* The typedef names below are the interface to the included template;
   * they must keep exactly these names. */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 instantiation of the "upd" convolution template for the
 * custom/custom layout, compiled for the LIBXSMM_X86_AVX512_CORE target.
 *
 * Unlike the non-"_emu" variant, LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI is not
 * defined before the BF16 macro template is included.
 * NOTE(review): this presumably selects emulated BF16 <-> FP32 conversion -
 * confirm in libxsmm_dnn_bf16_macros_define.tpl.c.
 *
 * handle        convolution layer handle, consumed by the included template
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status`: LIBXSMM_DNN_SUCCESS unless changed, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without
 * LIBXSMM_INTRINSICS_AVX512_CORE. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  /* The typedef names below are the interface to the included templates. */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c"
  /* drop the helper macros again so they do not leak into later code */
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* BF16 instantiation of the AMX flavour of the "upd" convolution template
 * for the custom/custom layout, compiled for the LIBXSMM_X86_AVX512_CORE
 * target and without LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI.
 *
 * Differences to the non-AMX "_emu" variant visible here: the batch-reduce
 * kernel type is the strided one (libxsmm_bsmmfunction_reducebatch_strd),
 * and three kernel variables are declared for use by the template.
 *
 * handle        convolution layer handle; upd_config_kernel is read from it
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status`: LIBXSMM_DNN_SUCCESS unless changed, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without
 * LIBXSMM_INTRINSICS_AVX512_CORE. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  /* The typedef names below are the interface to the included templates. */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function;
  /* Not referenced again in this function; presumably consumed by the
   * template included below - confirm in the .tpl.c file. */
  gemm_function tile_config_kernel = handle->upd_config_kernel;
  gemm_function gemm_kernel = NULL;
  gemm_br_function br_gemm_kernel = NULL;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c"
  /* drop the helper macros again so they do not leak into later code */
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16 exists in two
 * flavours, selected at compile time:
 *  - with LIBXSMM_INTRINSICS_AVX512_CPX: a template instantiation with
 *    LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI defined around the includes,
 *  - otherwise: a thin wrapper forwarding to the "_emu" variant. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 "upd" convolution for the custom/custom layout, compiled for the
 * LIBXSMM_X86_AVX512_CPX target.
 *
 * handle        convolution layer handle, consumed by the included template
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status` (LIBXSMM_DNN_SUCCESS unless changed). */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* NOTE(review): this guard repeats the enclosing #if, so the #else branch
   * below cannot be compiled in this flavour of the function. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
  /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  /* The typedef names below are the interface to the included templates. */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
  /* Defined only for the duration of the template includes (see #undef). */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Fallback without LIBXSMM_INTRINSICS_AVX512_CPX: same signature, forwards
 * all arguments to the "_emu" variant and returns its result. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
}
#endif
/* libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx exists in two
 * flavours, selected at compile time:
 *  - with LIBXSMM_INTRINSICS_AVX512_CPX: an instantiation of the AMX
 *    template with LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI defined,
 *  - otherwise: a thin wrapper forwarding to the "_emu_amx" variant. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 AMX-flavoured "upd" convolution for the custom/custom layout,
 * compiled for the LIBXSMM_X86_AVX512_CPX target.
 *
 * handle        convolution layer handle; upd_config_kernel is read from it
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status` (LIBXSMM_DNN_SUCCESS unless changed). */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* NOTE(review): this guard repeats the enclosing #if, so the #else branch
   * below cannot be compiled in this flavour of the function. */
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
  /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  /* The typedef names below are the interface to the included templates. */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_bsmmfunction gemm_function;
  /* Defined only for the duration of the template includes (see #undef). */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
  typedef libxsmm_bsmmfunction_reducebatch_strd gemm_br_function;
  /* Not referenced again in this function; presumably consumed by the
   * template included below - confirm in the .tpl.c file. */
  gemm_function tile_config_kernel = handle->upd_config_kernel;
  gemm_function gemm_kernel = NULL;
  gemm_br_function br_gemm_kernel = NULL;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic_bf16_amx.tpl.c"
# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Fallback without LIBXSMM_INTRINSICS_AVX512_CPX: same signature, forwards
 * all arguments to the "_emu_amx" variant and returns its result. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
}
#endif
/* FP32 instantiation of the shared NHWC "upd" convolution template in its
 * NHWC/custom configuration, compiled for the LIBXSMM_X86_AVX512 target.
 *
 * The template file is shared with the NHWC/RSCK variant; the
 * LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM macro defined around the
 * include selects this configuration.
 *
 * handle        convolution layer handle, consumed by the included template
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status`: LIBXSMM_DNN_SUCCESS unless changed, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without
 * LIBXSMM_INTRINSICS_AVX512. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512)
  /*__AVX512F__*/
  /* The typedef names below are the interface to the included template. */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
  /* Defined only for the duration of the template include (see #undef). */
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* FP32 instantiation of the shared NHWC "upd" convolution template in its
 * NHWC/RSCK configuration, compiled for the LIBXSMM_X86_AVX512 target.
 *
 * The template file is shared with the NHWC/custom variant; the
 * LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK macro defined around the
 * include selects this configuration.
 *
 * handle        convolution layer handle, consumed by the included template
 * start_thread  forwarded to the template
 * tid           forwarded to the template
 *
 * Returns the value of `status`: LIBXSMM_DNN_SUCCESS unless changed, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without
 * LIBXSMM_INTRINSICS_AVX512. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512)
  /*__AVX512F__*/
  /* The typedef names below are the interface to the included template. */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
  /* Defined only for the duration of the template include (see #undef). */
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
#else
  /* should not happen */
  LIBXSMM_UNUSED(handle);
  LIBXSMM_UNUSED(start_thread);
  LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
/* Entry point of the "upd" convolution pass for the custom/custom layout.
 *
 * Validates that all required buffers are bound, then dispatches on the
 * target architecture and the data types stored in the handle:
 *  - AVX512-capable target, FP32 in/out:  ..._f32_f32
 *  - AVX512-capable target, BF16 in/out:  one of the ..._bf16_bf16* variants,
 *    chosen by target_archid and by the intrinsics level this file was
 *    compiled with
 *  - any other target, FP32 in/out:       generic template, instantiated inline
 *
 * handle        convolution layer handle
 * start_thread  forwarded to the selected kernel
 * tid           forwarded to the selected kernel
 *
 * Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND if reg_input, grad_output,
 * grad_filter or scratch is unset, LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE if no
 * kernel matches the data types / architecture, otherwise the status of the
 * selected kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* input, output gradient, filter gradient and scratch must all be bound */
  if (0 == handle->reg_input || 0 == handle->grad_output || 0 == handle->grad_filter || 0 == handle->scratch) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }

#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* AVX512-capable target: use the architecture-specific kernels */
  if ((LIBXSMM_X86_AVX512 <= handle->target_archid) && (LIBXSMM_X86_ALLFEAT >= handle->target_archid)) {
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_out) {
      return libxsmm_dnn_convolve_st_upd_custom_custom_f32_f32(handle, start_thread, tid);
    }
# if defined(LIBXSMM_INTRINSICS_AVX512_CPX) || defined(LIBXSMM_INTRINSICS_AVX512_CORE)
    if (LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype_out) {
#   if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
      /* [CORE, CPX): emulated BF16 kernel */
      if (handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_CPX) {
        return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
      }
      /* [CPX, SPR): kernel built with the CPX intrinsics */
      if (handle->target_archid >= LIBXSMM_X86_AVX512_CPX && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
        return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16(handle, start_thread, tid);
      }
      /* [SPR, ...): AMX kernel */
      if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
        return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_amx(handle, start_thread, tid);
      }
#   else /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
      /* [CORE, SPR): emulated BF16 kernel */
      if (handle->target_archid >= LIBXSMM_X86_AVX512_CORE && handle->target_archid < LIBXSMM_X86_AVX512_SPR) {
        return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu(handle, start_thread, tid);
      }
      /* [SPR, ...): emulated BF16 AMX kernel */
      if (handle->target_archid >= LIBXSMM_X86_AVX512_SPR) {
        return libxsmm_dnn_convolve_st_upd_custom_custom_bf16_bf16_emu_amx(handle, start_thread, tid);
      }
#   endif
    }
# endif
    /* no kernel for this data type / architecture combination */
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
#endif

  /* generic fallback: only FP32 is available */
  if (LIBXSMM_DNN_DATATYPE_F32 != handle->datatype_in || LIBXSMM_DNN_DATATYPE_F32 != handle->datatype_out) {
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  {
    /* the typedef names are the interface to the included template */
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
# include "template/libxsmm_dnn_convolve_st_upd_custom_custom_generic.tpl.c"
  }

  return status;
}
/* Entry point of the "upd" convolution pass for the NHWC/custom layout.
 *
 * Validates that all required buffers are bound, then runs the FP32 kernel:
 * the AVX512 instantiation on AVX512-capable targets, otherwise the generic
 * template instantiated inline. Only FP32 in/out is supported.
 *
 * handle        convolution layer handle
 * start_thread  forwarded to the selected kernel
 * tid           forwarded to the selected kernel
 *
 * Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND if reg_input, grad_output,
 * grad_filter or scratch is unset, LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE for
 * non-FP32 data, otherwise the status of the selected kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* input, output gradient, filter gradient and scratch must all be bound */
  if (0 == handle->reg_input || 0 == handle->grad_output || 0 == handle->grad_filter || 0 == handle->scratch) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }

#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* AVX512-capable target: use the architecture-specific kernel */
  if ((LIBXSMM_X86_AVX512 <= handle->target_archid) && (LIBXSMM_X86_ALLFEAT >= handle->target_archid)) {
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_out) {
      return libxsmm_dnn_convolve_st_upd_nhwc_custom_f32_f32(handle, start_thread, tid);
    }
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
#endif

  /* generic fallback: only FP32 is available */
  if (LIBXSMM_DNN_DATATYPE_F32 != handle->datatype_in || LIBXSMM_DNN_DATATYPE_F32 != handle->datatype_out) {
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  {
    /* the typedef names are the interface to the included template */
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_CUSTOM
  }

  return status;
}
/* Entry point of the "upd" convolution pass for the NHWC/RSCK layout.
 *
 * Validates that all required buffers are bound, then runs the FP32 kernel:
 * the AVX512 instantiation on AVX512-capable targets, otherwise the generic
 * template instantiated inline. Only FP32 in/out is supported.
 *
 * handle        convolution layer handle
 * start_thread  forwarded to the selected kernel
 * tid           forwarded to the selected kernel
 *
 * Returns LIBXSMM_DNN_ERR_DATA_NOT_BOUND if reg_input, grad_output,
 * grad_filter or scratch is unset, LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE for
 * non-FP32 data, otherwise the status of the selected kernel. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* input, output gradient, filter gradient and scratch must all be bound */
  if (0 == handle->reg_input || 0 == handle->grad_output || 0 == handle->grad_filter || 0 == handle->scratch) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }

#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* AVX512-capable target: use the architecture-specific kernel */
  if ((LIBXSMM_X86_AVX512 <= handle->target_archid) && (LIBXSMM_X86_ALLFEAT >= handle->target_archid)) {
    if (LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_in && LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype_out) {
      return libxsmm_dnn_convolve_st_upd_nhwc_rsck_f32_f32(handle, start_thread, tid);
    }
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
#endif

  /* generic fallback: only FP32 is available */
  if (LIBXSMM_DNN_DATATYPE_F32 != handle->datatype_in || LIBXSMM_DNN_DATATYPE_F32 != handle->datatype_out) {
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }
  {
    /* the typedef names are the interface to the included template */
    typedef float element_input_type;
    typedef float element_output_type;
    typedef float element_filter_type;
    typedef libxsmm_smmfunction gemm_function;
    typedef libxsmm_smmfunction_reducebatch_addr gemm_br_function;
#define LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
# include "template/libxsmm_dnn_convolve_st_upd_nhwc_custom-rsck_generic.tpl.c"
#undef LIBXSMM_DNN_TPL_UPD_DIRECT_GENERIC_NHWC_RSCK
  }

  return status;
}
third_party/libxsmm/src/libxsmm_dnn_convolution_weight_update.h
0 → 100644
View file @
c454d419
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Rajkishore Barik, Alexander Heinecke (Intel Corp.)
******************************************************************************/
/* Internal entry points of the convolution weight-update pass, one per
 * supported tensor/filter layout combination. All three share the same
 * signature: the layer handle plus two thread-related integers
 * (start_thread, tid), and return a libxsmm_dnn_err_t status code. */
#ifndef LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H
#define LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H
#include <libxsmm_dnn_convolution.h>
/* weight update for the custom/custom layout */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_custom_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
/* weight update for the NHWC/RSCK layout */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_rsck(libxsmm_dnn_layer* handle, int start_thread, int tid);
/* weight update for the NHWC/custom layout */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_convolve_st_upd_nhwc_custom(libxsmm_dnn_layer* handle, int start_thread, int tid);
#endif
/* LIBXSMM_DNN_CONVOLUTION_WEIGHT_UPDATE_H */
Prev
1
…
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment