libxsmm_dnn_optimizer_sgd.c 4.05 KB
Newer Older
lisj's avatar
lisj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_optimizer_sgd.h"
#include "libxsmm_main.h"


/* Forward declarations of the two AVX512 SGD kernels defined below.
 * libxsmm_dnn_optimizer_sgd_st picks one of them by handle->desc.datatype
 * (F32 or BF16) when AVX512 is both compiled in and available at runtime. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_f32_f32(libxsmm_dnn_optimizer* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_bf16_bf16(libxsmm_dnn_optimizer* handle, int start_thread, int tid);


/* AVX512 SGD weight update for an optimizer handle with FP32 filters.
 *
 * The update body lives in template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c.
 * That template is textually included below with:
 *   - element_filter_type = float
 *   - LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512 defined (presumably this selects
 *     the AVX512 variant inside the template -- confirm against the .tpl.c)
 *
 * handle       - optimizer handle, handed to the template
 * start_thread - passed through to the template (presumably the first thread
 *                id of the team -- TODO confirm)
 * tid          - passed through to the template (the calling thread's id,
 *                judging by the name)
 *
 * Returns `status`. It starts as LIBXSMM_DNN_SUCCESS, and the included
 * template may update it. The function instead returns
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when this file was compiled without
 * LIBXSMM_INTRINSICS_AVX512.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_f32_f32(libxsmm_dnn_optimizer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element type consumed by the included template */
  typedef float element_filter_type;

  /* The variant macro must be defined before the include. It is undefined
   * right after, so it does not leak into the generic instantiation of the
   * same template in libxsmm_dnn_optimizer_sgd_st. */
# define LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
# undef LIBXSMM_DNN_OPTIMIZER_SGD_F32_AVX512
#else /* should not happen: the dispatcher only calls this inside its own LIBXSMM_INTRINSICS_AVX512 guard */
  /* mark the parameters as used to avoid unused-parameter warnings */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}


/* AVX512 SGD weight update for an optimizer handle with BF16 filters.
 *
 * The update body lives in template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c.
 * That template is textually included below with:
 *   - element_filter_type = libxsmm_bfloat16
 *   - element_master_type = float (the dispatcher requires
 *     handle->master_filter to be bound for BF16)
 *   - LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512 defined (presumably this selects
 *     the AVX512 BF16 variant inside the template -- confirm against the .tpl.c)
 *
 * handle       - optimizer handle, handed to the template
 * start_thread - passed through to the template (presumably the first thread
 *                id of the team -- TODO confirm)
 * tid          - passed through to the template (the calling thread's id,
 *                judging by the name)
 *
 * Returns `status`. It starts as LIBXSMM_DNN_SUCCESS, and the included
 * template may update it. The function instead returns
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when this file was compiled without
 * LIBXSMM_INTRINSICS_AVX512.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st_bf16_bf16(libxsmm_dnn_optimizer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types consumed by the included template: BF16 filter values,
   * float master copy */
  typedef libxsmm_bfloat16 element_filter_type;
  typedef float            element_master_type;

  /* The variant macro must be defined before the include. It is undefined
   * right after, so it does not leak into the other instantiations of the
   * same template in this file. */
# define LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
# undef LIBXSMM_DNN_OPTIMIZER_SGD_BF16_AVX512
#else /* should not happen: the dispatcher only calls this inside its own LIBXSMM_INTRINSICS_AVX512 guard */
  /* mark the parameters as used to avoid unused-parameter warnings */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}


/* SGD weight-update entry point: validates the bound tensors and the
 * datatype, then runs the matching implementation.
 *
 * Validation order (first failure wins):
 *   1. reg_filter or grad_filter unbound           -> LIBXSMM_DNN_ERR_DATA_NOT_BOUND
 *   2. BF16 datatype without a bound master_filter -> LIBXSMM_DNN_ERR_DATA_NOT_BOUND
 *   3. datatype other than F32 / BF16              -> LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE
 *
 * Dispatch:
 *   - If AVX512 intrinsics were compiled in AND libxsmm_target_archid reports
 *     at least LIBXSMM_X86_AVX512 at runtime, one of the AVX512 kernels runs.
 *   - Otherwise the generic template is instantiated inline for the datatype.
 *
 * handle, start_thread and tid are forwarded unchanged to whichever
 * implementation runs. The returned status is LIBXSMM_DNN_SUCCESS unless
 * validation fails or the chosen implementation reports otherwise.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_optimizer_sgd_st(libxsmm_dnn_optimizer* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* the filter and its gradient are mandatory for every datatype */
  if ( 0 == handle->reg_filter || 0 == handle->grad_filter ) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }
  /* BF16 additionally needs the float master copy of the filter */
  if ( LIBXSMM_DNN_DATATYPE_BF16 == handle->desc.datatype && 0 == handle->master_filter ) {
    return LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
  }
  /* neither the AVX512 kernels nor the generic template handle other types */
  if ( LIBXSMM_DNN_DATATYPE_F32 != handle->desc.datatype && LIBXSMM_DNN_DATATYPE_BF16 != handle->desc.datatype ) {
    return LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }

#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* AVX512 code is present in the binary; also require runtime CPU support */
  if ( LIBXSMM_X86_AVX512 <= libxsmm_target_archid ) {
    status = ( LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype )
           ? libxsmm_dnn_optimizer_sgd_st_f32_f32( handle, start_thread, tid )
           : libxsmm_dnn_optimizer_sgd_st_bf16_bf16( handle, start_thread, tid );
  } else
#endif
  {
    /* Portable fallback: instantiate the generic template for the element
     * type. Each typedef must be the first thing in its block, directly
     * ahead of the include that consumes it. */
    if ( LIBXSMM_DNN_DATATYPE_F32 == handle->desc.datatype ) {
      typedef float element_filter_type;

# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
    } else {
      typedef libxsmm_bfloat16 element_filter_type;
      typedef float            element_master_type;

# define LIBXSMM_DNN_OPTIMIZER_SGD_BF16
# include "template/libxsmm_dnn_optimizer_sgd_st_generic.tpl.c"
# undef LIBXSMM_DNN_OPTIMIZER_SGD_BF16
    }
  }

  return status;
}