/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_TESTS_CPP_TESTS_TESTUTILS_H_
#define LIGHTGBM_TESTS_CPP_TESTS_TESTUTILS_H_
#include <LightGBM/c_api.h>
#include <LightGBM/dataset.h>

#include <vector>

using LightGBM::Metadata;

namespace LightGBM {

class TestUtils {
 public:
19
20
21
22
  /*!
   * Creates a Dataset from the internal repository examples.
   */
  static int LoadDatasetFromExamples(const char* filename, const char* config, DatasetHandle* out);
23
24


25
26
27
28
29
30
31
32
33
34
35
  /*!
   * Creates a dense Dataset of random values.
   */
  static void CreateRandomDenseData(int32_t nrows,
    int32_t ncols,
    int32_t nclasses,
    std::vector<double>* features,
    std::vector<float>* labels,
    std::vector<float>* weights,
    std::vector<double>* init_scores,
    std::vector<int32_t>* groups);
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
  /*!
   * Creates a CSR sparse Dataset of random values.
   */
  static void CreateRandomSparseData(int32_t nrows,
    int32_t ncols,
    int32_t nclasses,
    float sparse_percent,
    std::vector<int32_t>* indptr,
    std::vector<int32_t>* indices,
    std::vector<double>* values,
    std::vector<float>* labels,
    std::vector<float>* weights,
    std::vector<double>* init_scores,
    std::vector<int32_t>* groups);
51

52
53
54
55
56
57
58
59
60
  /*!
   * Creates a batch of Metadata of random values.
   */
  static void CreateRandomMetadata(int32_t nrows,
    int32_t nclasses,
    std::vector<float>* labels,
    std::vector<float>* weights,
    std::vector<double>* init_scores,
    std::vector<int32_t>* groups);
61

62
63
64
65
66
67
68
69
70
71
72
73
74
  /*!
   * Pushes nrows of data to a Dataset in batches of batch_count.
   */
  static void StreamDenseDataset(DatasetHandle dataset_handle,
    int32_t nrows,
    int32_t ncols,
    int32_t nclasses,
    int32_t batch_count,
    const std::vector<double>* features,
    const std::vector<float>* labels,
    const std::vector<float>* weights,
    const std::vector<double>* init_scores,
    const std::vector<int32_t>* groups);
75

76
77
78
79
80
81
82
83
84
85
86
87
88
89
  /*!
   * Pushes nrows of data to a Dataset in batches of batch_count.
   */
  static void StreamSparseDataset(DatasetHandle dataset_handle,
    int32_t nrows,
    int32_t nclasses,
    int32_t batch_count,
    const std::vector<int32_t>* indptr,
    const std::vector<int32_t>* indices,
    const std::vector<double>* values,
    const std::vector<float>* labels,
    const std::vector<float>* weights,
    const std::vector<double>* init_scores,
    const std::vector<int32_t>* groups);
90

91
92
93
94
95
96
97
98
  /*!
   * Validates metadata against reference vectors.
   */
  static void AssertMetadata(const Metadata* metadata,
    const std::vector<float>* labels,
    const std::vector<float>* weights,
    const std::vector<double>* init_scores,
    const std::vector<int32_t>* groups);
99

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
  static const double* CreateInitScoreBatch(std::vector<double>* init_score_batch,
    int32_t index,
    int32_t nrows,
    int32_t nclasses,
    int32_t batch_count,
    const std::vector<double>* original_init_scores);

 private:
  static void PushSparseBatch(DatasetHandle dataset_handle,
    int32_t nrows,
    int32_t nclasses,
    int32_t batch_count,
    const std::vector<int32_t>* indptr,
    const int32_t* indptr_ptr,
    const int32_t* indices_ptr,
    const double* values_ptr,
    const float* labels_ptr,
    const float* weights_ptr,
    const std::vector<double>* init_scores,
    const int32_t* groups_ptr,
    int32_t thread_count,
    int32_t thread_id);
122
123
};
}  // namespace LightGBM
#endif  // LIGHTGBM_TESTS_CPP_TESTS_TESTUTILS_H_