libxsmm_matrixeqn.h 6.43 KB
Newer Older
lisj's avatar
lisj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_MATRIXEQN_H
#define LIBXSMM_MATRIXEQN_H

/* Operand-position selectors (0, 1, 2).
 * NOTE(review): presumably the values passed as the `side` argument of
 * get_bcast_type_binary()/get_bcast_type_ternary() -- confirm against callers.
 * NOTE(review): these names are generic, unprefixed, and defined before
 * <libxsmm.h> is included, so they may collide with identifiers in other
 * headers; consider a LIBXSMM_-prefixed spelling when callers can be updated. */
#define LEFT 0
#define RIGHT 1
#define RIGHT2 2

#include <libxsmm.h>
/**
 * TF includes src/libxsmm_main.h and uses LIBXSMM's sync primitives
 * without including libxsmm_sync.h. However, libxsmm_sync.h shall be
 * an explicit include, separate from including libxsmm.h.
 */
#include "libxsmm_sync.h"

/**
 * Kind tag stored in libxsmm_matrix_eqn_elem::type.
 * NONE is zero; every other enumerator occupies its own single bit.
 */
LIBXSMM_EXTERN_C typedef enum libxsmm_matrix_eqn_node_type {
  LIBXSMM_MATRIX_EQN_NODE_NONE    = 0x0, /* no kind assigned */
  LIBXSMM_MATRIX_EQN_NODE_UNARY   = 0x1, /* bit 0 */
  LIBXSMM_MATRIX_EQN_NODE_BINARY  = 0x2, /* bit 1 */
  LIBXSMM_MATRIX_EQN_NODE_TERNARY = 0x4, /* bit 2 */
  LIBXSMM_MATRIX_EQN_NODE_ARG     = 0x8  /* bit 3 */
} libxsmm_matrix_eqn_node_type;

/**
 * Broadcast classification used by equation arguments and temporaries
 * (see the bcast_type fields of libxsmm_matrix_eqn_arg and
 * libxsmm_matrix_eqn_tmp_info). NONE is zero; the remaining enumerators
 * are distinct single bits.
 */
LIBXSMM_EXTERN_C typedef enum libxsmm_matrix_eqn_bcast_type {
  LIBXSMM_MATRIX_EQN_BCAST_TYPE_NONE   = 0x0, /* no broadcast */
  LIBXSMM_MATRIX_EQN_BCAST_TYPE_ROW    = 0x1, /* bit 0 */
  LIBXSMM_MATRIX_EQN_BCAST_TYPE_COL    = 0x2, /* bit 1 */
  LIBXSMM_MATRIX_EQN_BCAST_TYPE_SCALAR = 0x4  /* bit 2 */
} libxsmm_matrix_eqn_bcast_type;

/**
 * Description of a unary operator: which element-wise unary operation,
 * its flags, and an associated data type. Stored as the `u_op` member of
 * libxsmm_matrix_eqn_info.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_unary_op {
  libxsmm_meltw_unary_type  type;  /* unary operation selector */
  libxsmm_meltw_unary_flags flags; /* modifiers for the unary operation */
  libxsmm_datatype          dtype; /* NOTE(review): presumably the datatype of the op's result -- confirm */
} libxsmm_matrix_eqn_unary_op;

/**
 * Description of a binary operator: which element-wise binary operation,
 * its flags, and an associated data type. Stored as the `b_op` member of
 * libxsmm_matrix_eqn_info.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_binary_op {
  libxsmm_meltw_binary_type  type;  /* binary operation selector */
  libxsmm_meltw_binary_flags flags; /* modifiers for the binary operation */
  libxsmm_datatype           dtype; /* NOTE(review): presumably the datatype of the op's result -- confirm */
} libxsmm_matrix_eqn_binary_op;

/**
 * Description of a ternary operator: which element-wise ternary operation,
 * its flags, and an associated data type. Stored as the `t_op` member of
 * libxsmm_matrix_eqn_info.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_ternary_op {
  libxsmm_meltw_ternary_type  type;  /* ternary operation selector */
  libxsmm_meltw_ternary_flags flags; /* modifiers for the ternary operation */
  libxsmm_datatype            dtype; /* NOTE(review): presumably the datatype of the op's result -- confirm */
} libxsmm_matrix_eqn_ternary_op;

/**
 * Description of an argument (input operand) of a matrix equation.
 * Stored as the `arg` member of libxsmm_matrix_eqn_info.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_arg {
  libxsmm_blasint  m;           /* NOTE(review): presumably the number of rows -- confirm */
  libxsmm_blasint  n;           /* NOTE(review): presumably the number of columns -- confirm */
  libxsmm_blasint  ld;          /* NOTE(review): presumably the leading dimension -- confirm */
  libxsmm_blasint  in_pos;      /* NOTE(review): presumably the argument's position among the equation inputs -- confirm */
  libxsmm_blasint  offs_in_pos; /* NOTE(review): meaning not visible in this header; verify in the implementation */
  libxsmm_datatype dtype;       /* data type of the argument */
  libxsmm_matrix_eqn_bcast_type  bcast_type; /* broadcast classification of the argument */
} libxsmm_matrix_eqn_arg;

/**
 * Description of the temporary attached to an equation node (see
 * libxsmm_matrix_eqn_elem::tmp). Holds an id plus three parallel sets of
 * m/n/ld/dtype/bcast_type fields: unsuffixed, `_s`, and `_t`.
 * NOTE(review): the meaning of the `_s` and `_t` suffixes is not visible in
 * this header (possibly the second and third operand of a node) -- confirm
 * in the implementation before relying on it.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_tmp_info {
  libxsmm_blasint  id;   /* identifier of the temporary; NOTE(review): presumably an index into the tmp storage pool -- confirm */
  /* primary (unsuffixed) description */
  libxsmm_blasint  m;
  libxsmm_blasint  n;
  libxsmm_blasint  ld;
  libxsmm_datatype dtype;
  libxsmm_matrix_eqn_bcast_type  bcast_type;
  /* `_s` description (same field set as above) */
  libxsmm_blasint  m_s;
  libxsmm_blasint  n_s;
  libxsmm_blasint  ld_s;
  libxsmm_datatype dtype_s;
  libxsmm_matrix_eqn_bcast_type  bcast_type_s;
  /* `_t` description (same field set as above) */
  libxsmm_blasint  m_t;
  libxsmm_blasint  n_t;
  libxsmm_blasint  ld_t;
  libxsmm_datatype dtype_t;
  libxsmm_matrix_eqn_bcast_type  bcast_type_t;
} libxsmm_matrix_eqn_tmp_info;

/**
 * Per-node payload: exactly one of a unary, binary, or ternary operator
 * description, or an argument description. Being a union, the members share
 * storage. NOTE(review): presumably the active member is selected by
 * libxsmm_matrix_eqn_elem::type -- confirm against the implementation.
 */
LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_matrix_eqn_info {
  libxsmm_matrix_eqn_unary_op   u_op; /* unary operator description */
  libxsmm_matrix_eqn_binary_op  b_op; /* binary operator description */
  libxsmm_matrix_eqn_ternary_op t_op; /* ternary operator description */
  libxsmm_matrix_eqn_arg        arg;  /* argument (input) description */
} libxsmm_matrix_eqn_info;

/**
 * One node of a matrix equation. Nodes link to up to four other nodes
 * (le, ri, r2, up) and carry a kind tag, a kind-specific payload, and
 * bookkeeping values used by the helper functions declared in this header.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn_elem {
  struct libxsmm_matrix_eqn_elem* le; /* NOTE(review): presumably the left (first) child -- confirm */
  struct libxsmm_matrix_eqn_elem* ri; /* NOTE(review): presumably the right (second) child -- confirm */
  struct libxsmm_matrix_eqn_elem* r2; /* NOTE(review): presumably the third child (ternary nodes) -- confirm */
  struct libxsmm_matrix_eqn_elem* up; /* NOTE(review): presumably the parent node -- confirm */
  libxsmm_matrix_eqn_node_type    type; /* kind of this node */
  libxsmm_matrix_eqn_info         info; /* operator or argument description for this node */
  libxsmm_blasint                 reg_score;       /* NOTE(review): presumably set by libxsmm_matrix_eqn_assign_reg_scores() -- confirm */
  libxsmm_blasint                 visit_timestamp; /* NOTE(review): presumably the node's position in the execution order -- confirm */
  libxsmm_matrix_eqn_tmp_info     tmp;             /* temporary associated with this node */
  libxsmm_blasint                 max_tmp_size;        /* NOTE(review): units and scope not visible in this header */
  libxsmm_blasint                 n_args;              /* NOTE(review): presumably the number of arguments in this subtree -- confirm */
  libxsmm_blasint                 tree_max_comp_tsize; /* NOTE(review): meaning not visible in this header */
} libxsmm_matrix_eqn_elem;

/**
 * A matrix equation: pointers to two nodes (root and "current") plus four
 * state values stored as libxsmm_blasint.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_matrix_eqn {
  libxsmm_matrix_eqn_elem*        eqn_root; /* root node of the equation */
  libxsmm_matrix_eqn_elem*        eqn_cur;  /* NOTE(review): presumably the node currently being filled during construction -- confirm */
  libxsmm_blasint                 is_constructed; /* NOTE(review): presumably a 0/1 flag -- confirm */
  libxsmm_blasint                 is_optimized;   /* NOTE(review): presumably a 0/1 flag -- confirm */
  libxsmm_blasint                 unary_only;     /* NOTE(review): presumably a 0/1 flag: equation has only unary ops -- confirm */
  libxsmm_blasint                 binary_only;    /* NOTE(review): presumably a 0/1 flag: equation has only binary ops -- confirm */
} libxsmm_matrix_eqn;

/* Helper functions for matrix equation handling (library-internal, see LIBXSMM_API_INTERN).
 * The behavior notes below are derived from the signatures only; anything marked
 * NOTE(review) must be confirmed against the implementation. */

/* Looks up the equation for index eqn_idx.
 * NOTE(review): the result for an invalid index (e.g. NULL) is not visible in this header. */
LIBXSMM_API_INTERN libxsmm_matrix_eqn* libxsmm_matrix_eqn_get_equation( libxsmm_blasint eqn_idx );
/* Returns an int status for equation eqn_idx.
 * NOTE(review): presumably non-zero when the equation can be JIT-compiled -- confirm. */
LIBXSMM_API_INTERN int libxsmm_matrix_eqn_is_ready_for_jit( libxsmm_blasint eqn_idx );
/* Operates on the node cur_node; NOTE(review): presumably fills reg_score for the subtree -- confirm. */
LIBXSMM_API_INTERN void libxsmm_matrix_eqn_assign_reg_scores( libxsmm_matrix_eqn_elem* cur_node );
/* Takes a node, a pointer to a timestamp counter, a temporary count (n_max_tmp), and a pool array.
 * NOTE(review): presumably tmp_storage_pool has at least n_max_tmp entries -- confirm. */
LIBXSMM_API_INTERN void libxsmm_matrix_eqn_create_exec_plan( libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *global_timestamp, libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool );
/* Returns a libxsmm_blasint derived from the pool.
 * NOTE(review): presumably the id of a reserved slot; the result when no slot is free is not visible here. */
LIBXSMM_API_INTERN libxsmm_blasint reserve_tmp_storage(libxsmm_blasint n_max_tmp, libxsmm_blasint *tmp_storage_pool);
/* Takes a node and a pointer to the current timestamp counter. */
LIBXSMM_API_INTERN void libxsmm_generator_assign_new_timestamp(libxsmm_matrix_eqn_elem* cur_node, libxsmm_blasint *current_timestamp );
/* Equation-level timestamp assignment. */
LIBXSMM_API_INTERN void libxsmm_generator_matequation_assign_timestamps(libxsmm_matrix_eqn *eqn);
/* Equation-level re-optimization; NOTE(review): which fields are recomputed is not visible in this header. */
LIBXSMM_API_INTERN void libxsmm_generator_reoptimize_eqn(libxsmm_matrix_eqn *eqn);
/* Operates on the node cur_node; NOTE(review): presumably updates the node's tmp description -- confirm. */
LIBXSMM_API_INTERN void libxsmm_matrix_eqn_adjust_tmp_sizes( libxsmm_matrix_eqn_elem* cur_node );
/* Opcode predicates: each takes an opcode as unsigned int and returns an int.
 * NOTE(review): presumably non-zero means "true" -- confirm. */
LIBXSMM_API_INTERN int is_unary_opcode_reduce_kernel (unsigned int opcode);
LIBXSMM_API_INTERN int is_unary_opcode_transform_kernel (unsigned int opcode);
LIBXSMM_API_INTERN int is_unary_opcode_reduce_to_scalar (unsigned int opcode);
LIBXSMM_API_INTERN int is_binary_opcode_reduce_to_scalar (unsigned int opcode);

/* Maps unary-operator flags to a broadcast classification. */
LIBXSMM_API_INTERN
libxsmm_matrix_eqn_bcast_type get_bcast_type_unary(libxsmm_meltw_unary_flags flags);

/* Maps binary-operator flags to a broadcast classification for one operand.
 * NOTE(review): `side` presumably takes LEFT or RIGHT -- confirm against callers. */
LIBXSMM_API_INTERN
libxsmm_matrix_eqn_bcast_type get_bcast_type_binary(libxsmm_meltw_binary_flags flags, unsigned int side);

/* Maps ternary-operator flags to a broadcast classification for one operand.
 * NOTE(review): `side` presumably takes LEFT, RIGHT, or RIGHT2 -- confirm against callers. */
LIBXSMM_API_INTERN
libxsmm_matrix_eqn_bcast_type get_bcast_type_ternary(libxsmm_meltw_ternary_flags flags, unsigned int side);

/* Equation-level entry point for reassigning broadcast temporaries.
 * NOTE(review): exactly which nodes/fields are rewritten is not visible in this header. */
LIBXSMM_API_INTERN void libxsmm_matrix_eqn_reassign_bcast_tmp(libxsmm_matrix_eqn *eqn);
/* Node-level variant taking both the equation and a node.
 * NOTE(review): presumably processes the children of cur_node -- confirm. */
LIBXSMM_API_INTERN void libxsmm_matrix_eqn_reassign_children_bcast_tmp(libxsmm_matrix_eqn *eqn, libxsmm_matrix_eqn_elem* cur_node);


#endif /*LIBXSMM_MATRIXEQN_H*/