testutils.h 3.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_TESTUTILS_H_
#define LIGHTGBM_TESTUTILS_H_

#include <LightGBM/c_api.h>
#include <LightGBM/dataset.h>

#include <cstdint>
#include <vector>

using LightGBM::Metadata;

namespace LightGBM {

class TestUtils {
 public:
    /*!
     * \brief Creates a Dataset from one of the example data files in the repository.
     * \param filename Name of the example file to load (presumably resolved relative to the
     *                 repository's examples directory -- TODO confirm in testutils.cpp).
     * \param config Dataset configuration string (presumably LightGBM "key=value" parameters
     *               -- TODO confirm).
     * \param out Receives the handle of the created Dataset.
     * \return Status code; presumably follows the C API convention of 0 on success
     *         (NOTE(review): confirm against the implementation).
     */
    static int LoadDatasetFromExamples(const char* filename, const char* config, DatasetHandle* out);

    /*!
     * \brief Fills the output vectors with randomly generated dense data and metadata.
     *
     * No Dataset is created here (there is no DatasetHandle parameter); the generated
     * vectors match the inputs of StreamDenseDataset() and can be pushed with it.
     * NOTE(review): whether any output pointer may be nullptr to skip that field is not
     * visible from this header -- confirm in testutils.cpp.
     *
     * \param nrows Number of rows to generate.
     * \param ncols Number of feature columns per row.
     * \param nclasses Number of classes (presumably sets the label range and the number of
     *                 initial scores per row -- confirm).
     * \param features Output feature values (presumably row-major, nrows * ncols -- confirm).
     * \param labels Output labels.
     * \param weights Output weights.
     * \param init_scores Output initial scores.
     * \param groups Output query/group data.
     */
    static void CreateRandomDenseData(int32_t nrows,
      int32_t ncols,
      int32_t nclasses,
      std::vector<double>* features,
      std::vector<float>* labels,
      std::vector<float>* weights,
      std::vector<double>* init_scores,
      std::vector<int32_t>* groups);

    /*!
     * \brief Fills the output vectors with randomly generated CSR sparse data and metadata.
     *
     * No Dataset is created here (there is no DatasetHandle parameter); the generated
     * vectors match the inputs of StreamSparseDataset() and can be pushed with it.
     *
     * \param nrows Number of rows to generate.
     * \param ncols Number of feature columns per row.
     * \param nclasses Number of classes (presumably sets the label range and the number of
     *                 initial scores per row -- confirm).
     * \param sparse_percent Controls the sparsity of the generated matrix. Whether it is the
     *                       share of empty or of stored entries, and whether it is a fraction
     *                       or a percentage, is not visible here -- TODO confirm.
     * \param indptr Output CSR row-pointer array.
     * \param indices Output CSR column indices of the stored values.
     * \param values Output CSR stored values.
     * \param labels Output labels.
     * \param weights Output weights.
     * \param init_scores Output initial scores.
     * \param groups Output query/group data.
     */
    static void CreateRandomSparseData(int32_t nrows,
      int32_t ncols,
      int32_t nclasses,
      float sparse_percent,
      std::vector<int32_t>* indptr,
      std::vector<int32_t>* indices,
      std::vector<double>* values,
      std::vector<float>* labels,
      std::vector<float>* weights,
      std::vector<double>* init_scores,
      std::vector<int32_t>* groups);

    /*!
     * \brief Fills the output vectors with random metadata for nrows rows.
     * \param nrows Number of rows to generate metadata for.
     * \param nclasses Number of classes (presumably sets the label range and the number of
     *                 initial scores per row -- confirm).
     * \param labels Output labels.
     * \param weights Output weights.
     * \param init_scores Output initial scores.
     * \param groups Output query/group data.
     */
    static void CreateRandomMetadata(int32_t nrows,
      int32_t nclasses,
      std::vector<float>* labels,
      std::vector<float>* weights,
      std::vector<double>* init_scores,
      std::vector<int32_t>* groups);

    /*!
     * \brief Pushes nrows rows of dense data and their metadata into an existing Dataset, in batches.
     * \param dataset_handle Dataset to push into.
     * \param nrows Total number of rows to push.
     * \param ncols Number of feature columns per row.
     * \param nclasses Number of classes (initial scores per row -- assumption, confirm).
     * \param batch_count Batch size for the push loop (presumably the number of rows per
     *                    batch rather than the number of batches -- TODO confirm).
     * \param features Feature values to push (read-only).
     * \param labels Labels to push (read-only).
     * \param weights Weights to push (read-only).
     * \param init_scores Initial scores to push (read-only).
     * \param groups Query/group data to push (read-only).
     */
    static void StreamDenseDataset(DatasetHandle dataset_handle,
      int32_t nrows,
      int32_t ncols,
      int32_t nclasses,
      int32_t batch_count,
      const std::vector<double>* features,
      const std::vector<float>* labels,
      const std::vector<float>* weights,
      const std::vector<double>* init_scores,
      const std::vector<int32_t>* groups);

    /*!
     * \brief Pushes nrows rows of CSR sparse data and their metadata into an existing Dataset, in batches.
     * \param dataset_handle Dataset to push into.
     * \param nrows Total number of rows to push.
     * \param nclasses Number of classes (initial scores per row -- assumption, confirm).
     * \param batch_count Batch size for the push loop (presumably the number of rows per
     *                    batch rather than the number of batches -- TODO confirm).
     * \param indptr CSR row-pointer array (read-only).
     * \param indices CSR column indices (read-only).
     * \param values CSR stored values (read-only).
     * \param labels Labels to push (read-only).
     * \param weights Weights to push (read-only).
     * \param init_scores Initial scores to push (read-only).
     * \param groups Query/group data to push (read-only).
     */
    static void StreamSparseDataset(DatasetHandle dataset_handle,
      int32_t nrows,
      int32_t nclasses,
      int32_t batch_count,
      const std::vector<int32_t>* indptr,
      const std::vector<int32_t>* indices,
      const std::vector<double>* values,
      const std::vector<float>* labels,
      const std::vector<float>* weights,
      const std::vector<double>* init_scores,
      const std::vector<int32_t>* groups);

    /*!
     * \brief Asserts that the given Metadata matches the reference vectors.
     * \param metadata Metadata under test (read-only).
     * \param labels Expected labels.
     * \param weights Expected weights.
     * \param init_scores Expected initial scores.
     * \param groups Expected query/group data.
     */
    static void AssertMetadata(const Metadata* metadata,
      const std::vector<float>* labels,
      const std::vector<float>* weights,
      const std::vector<double>* init_scores,
      const std::vector<int32_t>* groups);

    /*!
     * \brief Prepares the initial scores for one batch of rows taken from the full set.
     * \param init_score_batch Output buffer that receives the batch's initial scores.
     * \param index Index of the first row of the batch (assumption -- confirm).
     * \param nrows Total number of rows covered by original_init_scores.
     * \param nclasses Number of classes, i.e. initial scores per row (assumption -- confirm).
     * \param batch_count Number of rows in the batch (assumption -- confirm).
     * \param original_init_scores Initial scores for the full data set (read-only).
     * \return Pointer to the batch's initial scores (presumably init_score_batch->data(), or
     *         nullptr when there are no initial scores -- TODO confirm). If it points into
     *         init_score_batch, any later modification of that vector invalidates it.
     */
    static const double* CreateInitScoreBatch(std::vector<double>* init_score_batch,
      int32_t index,
      int32_t nrows,
      int32_t nclasses,
      int32_t batch_count,
      const std::vector<double>* original_init_scores);
};
}  // namespace LightGBM
#endif  // LIGHTGBM_TESTUTILS_H_