libxsmm_dnn.h 6.87 KB
Newer Older
lisj's avatar
lisj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_H
#define LIBXSMM_DNN_H

#include "libxsmm_typedefs.h"

/** Status code type of the DNN API. It is intended to carry LIBXSMM_DNN_SUCCESS
    or one of the LIBXSMM_DNN_WARN_xxx / LIBXSMM_DNN_ERR_xxx values defined in
    this header; libxsmm_dnn_get_error() accepts a value of this type. */
typedef unsigned int libxsmm_dnn_err_t;

/** Define error and warning codes.
    The codes are plain integer constants grouped by numeric range:
    0 denotes success, values starting at 90000 are warnings, and values
    starting at 100000 are errors. They are kept as preprocessor constants,
    so they can also be tested in preprocessor conditionals. */
#define LIBXSMM_DNN_SUCCESS                             0

/* Warnings (90000 range): a generic fallback notice plus suboptimal-blocking
   notices for RNN and FC (N, C, and K variants each).
   NOTE(review): presumably non-fatal, i.e. the operation still completes;
   confirm against the implementation. */
#define LIBXSMM_DNN_WARN_FALLBACK                   90000
#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING  90001
#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING  90002
#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING  90003
#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING   90004
#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING   90005
#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING   90006

/* Errors (100000 range). The values are consecutive; new codes should be
   appended at the end so existing values stay stable. */
#define LIBXSMM_DNN_ERR_GENERAL                    100000
#define LIBXSMM_DNN_ERR_CREATE_HANDLE              100001
#define LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE       100002
#define LIBXSMM_DNN_ERR_INVALID_BLOCKING           100003
#define LIBXSMM_DNN_ERR_INVALID_HANDLE             100004
#define LIBXSMM_DNN_ERR_DATA_NOT_BOUND             100005
#define LIBXSMM_DNN_ERR_CREATE_TENSOR              100006
#define LIBXSMM_DNN_ERR_INVALID_TENSOR             100007
#define LIBXSMM_DNN_ERR_MISMATCH_TENSOR            100008
#define LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR      100009
#define LIBXSMM_DNN_ERR_INVALID_KIND               100010
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW        100011
#define LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT     100012
#define LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT     100013
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE    100014
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS        100015
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL     100016
#define LIBXSMM_DNN_ERR_CREATE_LAYOUT              100017
#define LIBXSMM_DNN_ERR_INVALID_LAYOUT             100018
#define LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH           100019
#define LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED        100020
#define LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE        100021
#define LIBXSMM_DNN_ERR_INVALID_ALGO               100022
#define LIBXSMM_DNN_ERR_INVALID_PADDING            100023
#define LIBXSMM_DNN_ERR_UNKNOWN_BIAS_TYPE          100024
#define LIBXSMM_DNN_ERR_MISMATCH_BIAS              100025
#define LIBXSMM_DNN_ERR_INVALID_HANDLE_BIAS        100026
#define LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL       100027
#define LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS       100028
#define LIBXSMM_DNN_ERR_NOT_IMPLEMENTED            100029
#define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER  100030
#define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION 100031
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN     100032
#define LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING        100033
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_FC          100034
#define LIBXSMM_DNN_ERR_INVALID_RNN_TYPE           100035
#define LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN        100036
#define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER  100037
#define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION 100038
#define LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION      100039

/**
 * Selects the compute pass an operation refers to.
 * The enumerators are numbered consecutively starting at zero; the values are
 * written out explicitly so they cannot shift if the list is edited.
 */
typedef enum libxsmm_dnn_compute_kind {
  LIBXSMM_DNN_COMPUTE_KIND_FWD    = 0, /* forward path */
  LIBXSMM_DNN_COMPUTE_KIND_BWD    = 1, /* backward path */
  LIBXSMM_DNN_COMPUTE_KIND_UPD    = 2, /* weight update */
  LIBXSMM_DNN_COMPUTE_KIND_BWDUPD = 3, /* backward and weight update combined, useful for RNNs */
  LIBXSMM_DNN_COMPUTE_KIND_ALL    = 4  /* all routines, needed by some init routines */
} libxsmm_dnn_compute_kind;

/** The following are quantization-related definitions; they may later be
    moved into a more central part of LIBXSMM. */
/* @TODO check position of these declarations and defines */
/** Overlays an unsigned int and a float in the same storage, so a value
    written through one member can be inspected through the other.
    NOTE(review): presumably used together with the F32 masking defines of this
    header, which assumes that float and unsigned int are both 32 bits wide;
    confirm for all supported targets. */
typedef union LIBXSMM_RETARGETABLE libxsmm_intfloat {
  unsigned int ui; /* integer view of the shared storage */
  float f;         /* floating-point view of the shared storage */
} libxsmm_intfloat;

/* F32 masking defines: bit masks and field widths for a 32-bit float whose
   bits are laid out as 1 sign bit, 8 exponent bits, and 23 mantissa bits
   (the IEEE-754 binary32 layout). */
/* sign bit (bit 31) */
#define LIBXSMM_DNN_MASK_SIGN_F32      0x80000000
/* Legacy spelling ("LIBXSNN") of the sign mask; kept as an alias so that
   existing code using the misspelled name continues to compile. Prefer
   LIBXSMM_DNN_MASK_SIGN_F32 in new code. */
#define LIBXSNN_DNN_MASK_SIGN_F32      LIBXSMM_DNN_MASK_SIGN_F32
/* exponent field (bits 23..30) */
#define LIBXSMM_DNN_MASK_EXP_F32       0x7f800000
/* mantissa field (bits 0..22) */
#define LIBXSMM_DNN_MASK_MANT_F32      0x007fffff
/* everything but the sign bit, i.e. exponent and mantissa */
#define LIBXSMM_DNN_MASK_ABS_F32       0x7fffffff
/* all 32 bits */
#define LIBXSMM_DNN_MASK_FULL_F32      0xffffffff
/* number of mantissa bits */
#define LIBXSMM_DNN_MANT_SZ_F32        23
/* total number of bits */
#define LIBXSMM_DNN_SZ_F32             32

/* DFP16 masking defines */
/* number of mantissa bits of the DFP16 format */
#define LIBXSMM_DNN_MANT_DFP16         15
/* Resolution of DFP16: the result of libxsmm_sexp2_i8i() for the negated
   mantissa width. NOTE(review): libxsmm_sexp2_i8i is declared elsewhere;
   presumably it yields 2^x so that this evaluates to 2^-15; confirm. */
#define LIBXSMM_DNN_RES_DFP16          libxsmm_sexp2_i8i(-(LIBXSMM_DNN_MANT_DFP16))
/* Legacy spelling ("LIXSMMM") of the resolution macro; kept as an alias so
   that existing code using the misspelled name continues to compile. Prefer
   LIBXSMM_DNN_RES_DFP16 in new code. */
#define LIXSMMM_DNN_RES_DFP16          LIBXSMM_DNN_RES_DFP16

/* Quantization Rounding Defines: integer selectors (80000 range) naming a
   rounding behavior. NOTE(review): presumably these are the values expected
   by the round_mode argument of the quantize functions declared in this
   header; confirm against the implementation. */
/* no rounding */
#define LIBXSMM_DNN_QUANT_NO_ROUND       80000
/* biased rounding */
#define LIBXSMM_DNN_QUANT_BIAS_ROUND     80001
/* stochastic rounding */
#define LIBXSMM_DNN_QUANT_STOCH_ROUND    80002
/* round to nearest */
#define LIBXSMM_DNN_QUANT_NEAREST_ROUND  80003
/* NOTE(review): "FPHW" presumably means rounding as done by the
   floating-point hardware; confirm. */
#define LIBXSMM_DNN_QUANT_FPHW_ROUND     80004

/** Get the string for an error code.
    @param code a status value of type libxsmm_dnn_err_t.
    @return a C string for the given code.
    NOTE(review): ownership and lifetime of the returned string are not
    visible here; presumably a static string that must not be freed; confirm. */
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code);
/** Query a size for the given datatype.
    NOTE(review): presumably the size in bytes of one element; confirm. */
LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype);
/** Query the SIMD width for the given datatype.
    NOTE(review): presumably the number of elements per vector on the target
    architecture; confirm the unit and what the result depends on. */
LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype);

/** Quantization helper functions converting between float buffers and
    16-bit integer (short) buffers.
    @TODO these need to be integrated better for all the different ways of
    quantization.
    NOTE(review): the parameter semantics below are inferred from names only
    and must be confirmed against the implementation:
    - scf: presumably a scaling factor; written through the pointer by the
      quantize functions and passed by value to the dequantize function.
    - add_shift: presumably an additional shift applied to the scaling factor.
    - round_mode: presumably one of the LIBXSMM_DNN_QUANT_xxx_ROUND values.
    - cblk/kblk/lp_blk: presumably blocking factors of the tensor layouts. */
/* flat buffer of `length` elements */
LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode );
/* activation tensor with dimensions N, C, H, W */
LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode );
/* filter tensor with dimensions K, C, R, S */
LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode );
/* flat buffer of `length` elements, from short back to float */
LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf );

/** BF16 <-> FP32 conversion functions operating on arrays; each reads from
    `in` (not modified, as it is const-qualified) and writes to `out`.
    @TODO we need to find a final place for those.
    Note that the names are not uniform ("f32" vs. "fp32", "length" vs. "len");
    they are kept as-is since they are part of the public interface.
    NOTE(review): the rounding behavior below is inferred from the function
    names only; confirm against the implementation. */
/* FP32 -> BF16; presumably by truncation (dropping the low-order bits) */
LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length);
/* FP32 -> BF16; "rnaz" presumably means round-to-nearest, ties away from zero */
LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len);
/* FP32 -> BF16; "rne" presumably means round-to-nearest, ties to even */
LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len);
/* BF16 -> FP32 */
LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length);

#endif /*LIBXSMM_DNN_H*/