libxsmm_dnn_fullyconnected.h 4.07 KB
Newer Older
lisj's avatar
lisj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_DNN_FULLYCONNECTED_H
#define LIBXSMM_DNN_FULLYCONNECTED_H

#include "libxsmm_dnn.h"
#include "libxsmm_dnn_tensor.h"

/** Opaque handle representing a LIBXSMM fully-connected layer.
 *  The struct is only forward-declared in this header; its contents are not visible
 *  to callers, who operate on it solely through pointers passed to the
 *  libxsmm_dnn_*fullyconnected* functions. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fullyconnected libxsmm_dnn_fullyconnected;

typedef enum libxsmm_dnn_fullyconnected_fuse_op {
  /* Elementary fusion flags: each non-zero value occupies a distinct bit. */
  LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE    = 0,
  LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS    = 1,
  LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU    = 2,
  LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU =
    LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS | LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU,       /* == 3 */
  LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID = 4,
  /* Combined variants equal the bitwise OR of their parts.
   * When fused together, the bias is applied first, then the activation. */
  LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID =
    LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS | LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID     /* == 5 */
} libxsmm_dnn_fullyconnected_fuse_op;

/** Configuration of a fully-connected layer, passed by value to
 *  libxsmm_dnn_create_fullyconnected.
 *  The field order and types define the struct's binary layout as seen by callers;
 *  do not reorder or retype fields. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fullyconnected_desc {
  int N;                                        /* number of images in mini-batch */
  int C;                                        /* number of input feature maps */
  int K;                                        /* number of output feature maps */
  int bn;                                       /* NOTE(review): presumably the block size applied to N -- confirm */
  int bk;                                       /* NOTE(review): presumably the block size applied to K -- confirm */
  int bc;                                       /* NOTE(review): presumably the block size applied to C -- confirm */
  int threads;                                  /* number of threads used */
  int compressed_A;                             /* NOTE(review): looks like a flag selecting a compressed "A" operand -- confirm semantics */
  int sparsity_factor_A;                        /* NOTE(review): meaning and valid range not visible here; presumably used together with compressed_A -- confirm */
  libxsmm_dnn_datatype datatype_in;             /* datatype used for all input-related buffers */
  libxsmm_dnn_datatype datatype_out;            /* datatype used for all output-related buffers */
  libxsmm_dnn_tensor_format buffer_format;      /* tensor format of the activation buffers */
  libxsmm_dnn_tensor_format filter_format;      /* tensor format of the filter buffers */
  libxsmm_dnn_fullyconnected_fuse_op fuse_ops;  /* operations fused into the layer (see libxsmm_dnn_fullyconnected_fuse_op) */
} libxsmm_dnn_fullyconnected_desc;

/** Creates a fully-connected handle from a layer descriptor.
 *  @param fullyconnected_desc  layer configuration, taken by value.
 *  @param status               receives the status of the call.
 *  @return the new handle (NOTE(review): presumably NULL on failure -- confirm). */
LIBXSMM_API libxsmm_dnn_fullyconnected* libxsmm_dnn_create_fullyconnected(
  libxsmm_dnn_fullyconnected_desc fullyconnected_desc,
  libxsmm_dnn_err_t* status);

/** Destroys a fully-connected handle
 *  (presumably one returned by libxsmm_dnn_create_fullyconnected).
 *  @param handle  handle to destroy.
 *  @return status of the call. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fullyconnected(
  const libxsmm_dnn_fullyconnected* handle);

/** Creates a tensor data-layout object for the tensor of role `type` belonging to `handle`.
 *  @param handle  fully-connected handle the layout is derived from.
 *  @param type    which of the handle's tensors the layout describes.
 *  @param status  receives the status of the call.
 *  @return the layout (NOTE(review): ownership/release path not visible here -- confirm). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fullyconnected_create_tensor_datalayout(
  const libxsmm_dnn_fullyconnected* handle,
  const libxsmm_dnn_tensor_type type,
  libxsmm_dnn_err_t* status);

/* Scratch-memory management for a fully-connected handle. The caller supplies the
 * scratch buffer through bind_scratch and detaches it again with release_scratch.
 * NOTE(review): units of get_scratch_size (presumably bytes) -- confirm. */
LIBXSMM_API void* libxsmm_dnn_fullyconnected_get_scratch_ptr(
  const libxsmm_dnn_fullyconnected* handle,
  libxsmm_dnn_err_t* status);
LIBXSMM_API size_t libxsmm_dnn_fullyconnected_get_scratch_size(
  const libxsmm_dnn_fullyconnected* handle,
  libxsmm_dnn_err_t* status);
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_scratch(
  libxsmm_dnn_fullyconnected* handle,
  const void* scratch);
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_scratch(
  libxsmm_dnn_fullyconnected* handle);

/* Tensor attachment: associate a tensor with one role (`type`) of the handle,
 * retrieve the tensor currently bound to a role, or detach it again. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_tensor(
  libxsmm_dnn_fullyconnected* handle,
  const libxsmm_dnn_tensor* tensor,
  const libxsmm_dnn_tensor_type type);
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fullyconnected_get_tensor(
  libxsmm_dnn_fullyconnected* handle,
  const libxsmm_dnn_tensor_type type,
  libxsmm_dnn_err_t* status);
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_tensor(
  libxsmm_dnn_fullyconnected* handle,
  const libxsmm_dnn_tensor_type type);

/** Executes the computation selected by `kind` on the handle.
 *  NOTE(review): the start_thread/tid pair suggests every participating thread calls
 *  this with its own tid -- confirm the threading contract against the implementation.
 *  @param handle        fully-connected handle with tensors (and scratch) bound.
 *  @param kind          which computation to run.
 *  @param start_thread  first thread id of the team (conceptually unsigned).
 *  @param tid           id of the calling thread (conceptually unsigned).
 *  @return status of the call. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_execute_st(
  libxsmm_dnn_fullyconnected* handle,
  libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread,
  /*unsigned*/int tid);

#endif /*LIBXSMM_DNN_FULLYCONNECTED_H*/