libxsmm_spmdm.h 3.8 KB
Newer Older
lisj's avatar
lisj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Nadathur Satish (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_SPMDM_H
#define LIBXSMM_SPMDM_H

#include "libxsmm_typedefs.h"


/**
 * Data-type selector for the SPMDM routines; a value of this type is stored
 * in libxsmm_spmdm_handle::datatype. The enumerator values are spelled out
 * explicitly and equal the ones the compiler would assign implicitly.
 */
typedef enum libxsmm_spmdm_datatype {
  /* Matches the *_fp32_thread entry points declared in this header. */
  LIBXSMM_SPMDM_DATATYPE_F32 = 0,
  /* Matches the *_bfloat16_thread entry points declared in this header. */
  LIBXSMM_SPMDM_DATATYPE_BFLOAT16 = 1
} libxsmm_spmdm_datatype;

/**
 * Handle describing one sparse-times-dense matrix multiply (SPMDM) problem.
 * A pointer to this structure is taken by every libxsmm_spmdm_* routine
 * declared in this header.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_spmdm_handle {
  /* The following are the matrix multiply dimensions: A (sparse): m X k, B (dense): k X n, Output C (dense): m X n */
  int m;
  int n;
  int k;
  /* The block sizes for A, B and C. */
  /* Here we fix A to be divided into 128 X 128 blocks, B/C to be 128 X 48 for HSW/BDW and 128 X 96 for SKX */
  int bm;
  int bn;
  int bk;
  /* The number of blocks for the m, n and k dimensions */
  int mb;
  int nb;
  int kb;
  /* Selects between F32 and BFLOAT16 (see libxsmm_spmdm_datatype). */
  libxsmm_spmdm_datatype datatype;
  /* Base pointers of scratch memory. NOTE(review): presumably allocated by libxsmm_spmdm_init
   * and released by libxsmm_spmdm_destroy -- confirm against the implementation. */
  char* base_ptr_scratch_A;
  char* base_ptr_scratch_B_scratch_C;
  /* Amount of scratch memory reserved per thread. NOTE(review): unit (presumably bytes) is not visible in this header -- confirm. */
  int memory_for_scratch_per_thread;
} libxsmm_spmdm_handle;

/**
 * This stores a single sparse slice (or block) of the sparse matrix A using a CSR representation
 * (rowidx, colidx, and values). Each slice corresponds to a bm X bk region of A and stores
 * indexes that are local to that region.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_CSR_sparseslice {
  /* Since bm and bk are assumed to be <=256, a 16-bit integer is enough to store the local rowidx, colidx */
  /* NOTE(review): array lengths and ownership of the three buffers are not visible in this header -- confirm against libxsmm_spmdm_init. */
  uint16_t* rowidx;
  uint16_t* colidx;
  /* Values are held as float; this header declares no separate slice type for BFLOAT16 input. */
  float*    values;
} libxsmm_CSR_sparseslice;


/**
 * Initializes an SPMDM handle for a problem of the given size.
 *
 * @param M,N,K               Problem dimensions. NOTE(review): presumably stored as handle->m/n/k,
 *                            i.e. A is M X K, B is K X N, C is M X N -- confirm in the implementation.
 * @param max_threads         NOTE(review): presumably the maximum number of threads that will call the
 *                            *_thread routines (used to size per-thread scratch) -- confirm.
 * @param handle              Handle to be set up (non-const, i.e. written by this call).
 * @param libxsmm_output_csr  Pointer-to-pointer out-parameter. NOTE(review): presumably receives the
 *                            array of sparse slices consumed by the createSparseSlice/compute
 *                            routines; allocation and ownership are not visible here -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_init(
  int M, int N, int K,
  int max_threads,
  libxsmm_spmdm_handle* handle,
  libxsmm_CSR_sparseslice** libxsmm_output_csr);

/**
 * Tears down a handle.
 *
 * @param handle  Handle to destroy (non-const). NOTE(review): presumably releases the scratch memory
 *                referenced by the handle; whether the sparse slices obtained from
 *                libxsmm_spmdm_init are released as well is not visible here -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_destroy(
  libxsmm_spmdm_handle* handle);

/**
 * Queries a block count from the handle (the handle is not modified).
 *
 * @param handle  Handle to query.
 * @return NOTE(review): presumably the number of distinct block_id values accepted by the
 *         libxsmm_spmdm_createSparseSlice_*_thread routines -- confirm in the implementation.
 */
LIBXSMM_API int libxsmm_spmdm_get_num_createSparseSlice_blocks(
  const libxsmm_spmdm_handle* handle);

/**
 * Queries a block count from the handle (the handle is not modified).
 *
 * @param handle  Handle to query.
 * @return NOTE(review): presumably the number of distinct block_id values accepted by the
 *         libxsmm_spmdm_compute_*_thread routines -- confirm in the implementation.
 */
LIBXSMM_API int libxsmm_spmdm_get_num_compute_blocks(
  const libxsmm_spmdm_handle* handle);

/**
 * This converts a dense representation of the sparse matrix to a 2D array of sparse slices
 * (single-precision input).
 *
 * @param handle                Problem description (read-only).
 * @param transa                Transpose flag for A. NOTE(review): accepted characters are not
 *                              visible in this header (presumably BLAS-style 'N'/'T') -- confirm.
 * @param a                     Dense input matrix A (read-only).
 * @param libxsmm_output_csr_a  Destination sparse slices.
 * @param block_id              Identifies the block to process. NOTE(review): valid range is presumably
 *                              [0, libxsmm_spmdm_get_num_createSparseSlice_blocks(handle)) -- confirm.
 * @param tid                   NOTE(review): presumably the calling thread's index -- confirm.
 * @param nthreads              NOTE(review): presumably the number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_createSparseSlice_fp32_thread(
  const libxsmm_spmdm_handle* handle,
  char transa,
  const float* a,
  libxsmm_CSR_sparseslice* libxsmm_output_csr_a,
  int block_id,
  int tid, int nthreads);

/**
 * BFLOAT16 counterpart of libxsmm_spmdm_createSparseSlice_fp32_thread: the parameter list is the
 * same except that the dense input `a` is given as libxsmm_bfloat16. The destination slice type
 * (libxsmm_CSR_sparseslice) keeps its values as float.
 *
 * @param handle                Problem description (read-only).
 * @param transa                Transpose flag for A. NOTE(review): accepted characters are not
 *                              visible in this header -- confirm.
 * @param a                     Dense input matrix A in BFLOAT16 (read-only).
 * @param libxsmm_output_csr_a  Destination sparse slices.
 * @param block_id              Identifies the block to process. NOTE(review): valid range is presumably
 *                              [0, libxsmm_spmdm_get_num_createSparseSlice_blocks(handle)) -- confirm.
 * @param tid                   NOTE(review): presumably the calling thread's index -- confirm.
 * @param nthreads              NOTE(review): presumably the number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_createSparseSlice_bfloat16_thread(
  const libxsmm_spmdm_handle* handle,
  char transa,
  const libxsmm_bfloat16* a,
  libxsmm_CSR_sparseslice* libxsmm_output_csr_a,
  int block_id,
  int tid, int nthreads);

/**
 * Single-precision compute routine operating on the sparse slices of A and the dense matrix B,
 * writing to the dense matrix C.
 * NOTE: This code currently ignores alpha input to the matrix multiply
 *
 * @param handle    Problem description (read-only).
 * @param transa    Transpose flag for A. NOTE(review): accepted characters are not visible here -- confirm.
 * @param transb    Transpose flag for B. NOTE(review): accepted characters are not visible here -- confirm.
 * @param alpha     Scalar passed by pointer; currently ignored (see note above).
 * @param a_sparse  Sparse slices of A (non-const pointer).
 * @param b         Dense matrix B (read-only).
 * @param transc    Transpose flag for C. NOTE(review): accepted characters are not visible here -- confirm.
 * @param beta      Scalar passed by pointer. NOTE(review): presumably scales the existing content of C
 *                  as in GEMM -- confirm.
 * @param c         Dense output matrix C.
 * @param block_id  Identifies the block to process. NOTE(review): valid range is presumably
 *                  [0, libxsmm_spmdm_get_num_compute_blocks(handle)) -- confirm.
 * @param tid       NOTE(review): presumably the calling thread's index -- confirm.
 * @param nthreads  NOTE(review): presumably the number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_compute_fp32_thread(
  const libxsmm_spmdm_handle* handle,
  char transa,
  char transb,
  const float* alpha,
  libxsmm_CSR_sparseslice* a_sparse,
  const float* b,
  char transc,
  const float* beta,
  float* c,
  int block_id,
  int tid, int nthreads);

/**
 * BFLOAT16 compute routine: alpha, beta and the dense matrix B are given as libxsmm_bfloat16,
 * while the output matrix C is single precision (float).
 * NOTE: This code currently ignores alpha input to the matrix multiply
 *
 * @param handle    Problem description (read-only).
 * @param transa    Transpose flag for A. NOTE(review): accepted characters are not visible here -- confirm.
 * @param transb    Transpose flag for B. NOTE(review): accepted characters are not visible here -- confirm.
 * @param alpha     Scalar passed by pointer; currently ignored (see note above).
 * @param a_sparse  Sparse slices of A (non-const pointer).
 * @param b         Dense matrix B in BFLOAT16 (read-only).
 * @param transc    Transpose flag for C. NOTE(review): accepted characters are not visible here -- confirm.
 * @param beta      Scalar passed by pointer. NOTE(review): presumably scales the existing content of C
 *                  as in GEMM -- confirm.
 * @param c         Dense output matrix C (float).
 * @param block_id  Identifies the block to process. NOTE(review): valid range is presumably
 *                  [0, libxsmm_spmdm_get_num_compute_blocks(handle)) -- confirm.
 * @param tid       NOTE(review): presumably the calling thread's index -- confirm.
 * @param nthreads  NOTE(review): presumably the number of participating threads -- confirm.
 */
LIBXSMM_API void libxsmm_spmdm_compute_bfloat16_thread(
  const libxsmm_spmdm_handle* handle,
  char transa,
  char transb,
  const libxsmm_bfloat16* alpha,
  libxsmm_CSR_sparseslice* a_sparse,
  const libxsmm_bfloat16* b,
  char transc,
  const libxsmm_bfloat16* beta,
  float* c,
  int block_id,
  int tid, int nthreads);

#endif /*LIBXSMM_SPMDM_H*/