/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>

#include <string>
#include <utility>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
13
Metadata::Metadata() {
Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
  num_weights_ = 0;
  num_init_score_ = 0;
  num_data_ = 0;
  num_queries_ = 0;
18
19
20
  weight_load_from_file_ = false;
  query_load_from_file_ = false;
  init_score_load_from_file_ = false;
21
22
23
  #ifdef USE_CUDA_EXP
  cuda_metadata_ = nullptr;
  #endif  // USE_CUDA_EXP
Guolin Ke's avatar
Guolin Ke committed
24
25
}

26
void Metadata::Init(const char* data_filename) {
Guolin Ke's avatar
Guolin Ke committed
27
  data_filename_ = data_filename;
28
  // for lambdarank, it needs query data for partition data in distributed learning
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
32
  LoadInitialScore(data_filename_);
Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
}

// Out-of-line defaulted destructor. Like the original empty body, it does no
// work beyond destroying the members.
Metadata::~Metadata() = default;

38
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
Guolin Ke's avatar
Guolin Ke committed
39
  num_data_ = num_data;
40
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
41
  if (weight_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
42
    if (!weights_.empty()) {
43
      Log::Info("Using weights in data file, ignoring the additional weights file");
Guolin Ke's avatar
Guolin Ke committed
44
      weights_.clear();
Guolin Ke's avatar
Guolin Ke committed
45
    }
46
    weights_ = std::vector<label_t>(num_data_, 0.0f);
Guolin Ke's avatar
Guolin Ke committed
47
    num_weights_ = num_data_;
48
    weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
49
50
  }
  if (query_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
51
    if (!query_boundaries_.empty()) {
52
      Log::Info("Using query id in data file, ignoring the additional query file");
Guolin Ke's avatar
Guolin Ke committed
53
      query_boundaries_.clear();
Guolin Ke's avatar
Guolin Ke committed
54
    }
Guolin Ke's avatar
Guolin Ke committed
55
    if (!query_weights_.empty()) { query_weights_.clear(); }
56
    queries_ = std::vector<data_size_t>(num_data_, 0);
57
    query_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
58
  }
Guolin Ke's avatar
Guolin Ke committed
59
60
}

Guolin Ke's avatar
Guolin Ke committed
61
62
63
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

64
  label_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
65
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
66
  for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
67
68
69
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
70
  if (!fullset.weights_.empty()) {
71
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
72
    num_weights_ = num_used_indices;
Guolin Ke's avatar
Guolin Ke committed
73
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
74
    for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
81
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
82
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
83
    init_score_ = std::vector<double>(static_cast<size_t>(num_used_indices) * num_class);
Guolin Ke's avatar
Guolin Ke committed
84
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
85
    #pragma omp parallel for schedule(static)
86
    for (int k = 0; k < num_class; ++k) {
87
88
89
90
      const size_t offset_dest = static_cast<size_t>(k) * num_data_;
      const size_t offset_src = static_cast<size_t>(k) * fullset.num_data_;
      for (data_size_t i = 0; i < num_used_indices; ++i) {
        init_score_[offset_dest + i] = fullset.init_score_[offset_src + used_indices[i]];
91
      }
Guolin Ke's avatar
Guolin Ke committed
92
93
94
95
96
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
97
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

Guolin Ke's avatar
Guolin Ke committed
130
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
131
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
132
133
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
134
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
135
  num_data_ = static_cast<data_size_t>(used_indices.size());
136
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
137
#pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
138
139
140
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
141
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
142
143
144
}

/*!
 * \brief Validate side data against the row count, or cut it down to local rows.
 *
 * With an empty `used_data_indices`, this does three things:
 * - converts per-row query ids (if any) into boundaries;
 * - checks weights, query boundaries and initial scores against num_data_;
 * - calls Log::Fatal on a mismatch.
 *
 * Otherwise, side data that was loaded from files is checked against
 * `num_all_data` and reduced to the rows in `used_data_indices`.
 * \param num_all_data Number of rows in the full, unpartitioned data
 * \param used_data_indices Global row indices kept here; empty means no partitioning
 */
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // need convert query_id to boundaries: each run of equal consecutive ids is one query
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      tmp_buffer.push_back(cur_cnt);
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
      queries_.clear();
    }
    // check weights
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // check query boundaries
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // contain initial score file; an empty dataset is rejected up front so the
    // modulo never divides by zero
    if (!init_score_.empty() && (num_data_ == 0 || (num_init_score_ % num_data_) != 0)) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for distributed training");
    }
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // check weights
    if (weight_load_from_file_) {
      if (weights_.size() > 0 && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // get local weights
      if (!weights_.empty()) {
        // move the global weights out instead of copying them
        std::vector<label_t> old_weights(std::move(weights_));
        // NOTE(review): assumes num_data_ already equals used_data_indices.size()
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static, 512)
        for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
          weights_[i] = old_weights[used_data_indices[i]];
        }
      }
    }
    if (query_load_from_file_) {
      // check query boundaries
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // get local query boundaries: keep queries whose rows are all present, contiguously
      if (!query_boundaries_.empty()) {
        std::vector<data_size_t> used_query;
        data_size_t data_idx = 0;
        for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
          data_size_t start = query_boundaries_[qid];
          data_size_t end = query_boundaries_[qid + 1];
          data_size_t len = end - start;
          if (used_data_indices[data_idx] > start) {
            continue;
          } else if (used_data_indices[data_idx] == start) {
            if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
              used_query.push_back(qid);
              data_idx += len;
            } else {
              Log::Fatal("Data partition error, data didn't match queries");
            }
          } else {
            Log::Fatal("Data partition error, data didn't match queries");
          }
        }
        std::vector<data_size_t> old_query_boundaries(std::move(query_boundaries_));
        query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
        num_queries_ = static_cast<data_size_t>(used_query.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          data_size_t qid = used_query[i];
          data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
          query_boundaries_[i + 1] = query_boundaries_[i] + len;
        }
      }
    }
    if (init_score_load_from_file_) {
      // contain initial score file; guard against dividing by zero rows
      if (!init_score_.empty() && (num_all_data == 0 || (num_init_score_ % num_all_data) != 0)) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }

      // get local initial scores (stored class by class, column k at k * rows)
      if (!init_score_.empty()) {
        std::vector<double> old_scores(std::move(init_score_));
        int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          const size_t offset_dest = static_cast<size_t>(k) * num_data_;
          const size_t offset_src = static_cast<size_t>(k) * num_all_data;
          for (size_t i = 0; i < used_data_indices.size(); ++i) {
            init_score_[offset_dest + i] = old_scores[offset_src + used_data_indices[i]];
          }
        }
      }
    }
    // re-load query weight
    LoadQueryWeights();
  }
  if (num_queries_ > 0) {
    Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.",
      data_filename_.c_str(), static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
  }
}

289
290
291
292
293
294
295
296
297
298
299
void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
300
  if (init_score_.empty()) { init_score_.resize(len); }
301
  num_init_score_ = len;
302

Guolin Ke's avatar
Guolin Ke committed
303
  #pragma omp parallel for schedule(static, 512) if (num_init_score_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
304
  for (int64_t i = 0; i < num_init_score_; ++i) {
305
    init_score_[i] = Common::AvoidInf(init_score[i]);
306
  }
307
  init_score_load_from_file_ = false;
308
309
310
311
312
  #ifdef USE_CUDA_EXP
  if (cuda_metadata_ != nullptr) {
    cuda_metadata_->SetInitScore(init_score_.data(), len);
  }
  #endif  // USE_CUDA_EXP
313
314
}

315
void Metadata::SetLabel(const label_t* label, data_size_t len) {
316
  std::lock_guard<std::mutex> lock(mutex_);
317
318
319
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
320
  if (num_data_ != len) {
321
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
322
  }
323
  if (label_.empty()) { label_.resize(num_data_); }
324

Guolin Ke's avatar
Guolin Ke committed
325
  #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
326
  for (data_size_t i = 0; i < num_data_; ++i) {
327
    label_[i] = Common::AvoidInf(label[i]);
Guolin Ke's avatar
Guolin Ke committed
328
  }
329
330
331
332
333
  #ifdef USE_CUDA_EXP
  if (cuda_metadata_ != nullptr) {
    cuda_metadata_->SetLabel(label_.data(), len);
  }
  #endif  // USE_CUDA_EXP
Guolin Ke's avatar
Guolin Ke committed
334
335
}

336
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
337
  std::lock_guard<std::mutex> lock(mutex_);
338
339
340
341
342
343
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
344
  if (num_data_ != len) {
345
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
346
  }
347
  if (weights_.empty()) { weights_.resize(num_data_); }
Guolin Ke's avatar
Guolin Ke committed
348
  num_weights_ = num_data_;
349

Guolin Ke's avatar
Guolin Ke committed
350
  #pragma omp parallel for schedule(static, 512) if (num_weights_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
351
  for (data_size_t i = 0; i < num_weights_; ++i) {
352
    weights_[i] = Common::AvoidInf(weights[i]);
Guolin Ke's avatar
Guolin Ke committed
353
354
  }
  LoadQueryWeights();
355
  weight_load_from_file_ = false;
356
357
358
359
360
  #ifdef USE_CUDA_EXP
  if (cuda_metadata_ != nullptr) {
    cuda_metadata_->SetWeights(weights_.data(), len);
  }
  #endif  // USE_CUDA_EXP
Guolin Ke's avatar
Guolin Ke committed
361
362
}

Guolin Ke's avatar
Guolin Ke committed
363
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
364
  std::lock_guard<std::mutex> lock(mutex_);
365
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
366
  if (query == nullptr || len == 0) {
367
368
369
370
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
371
  data_size_t sum = 0;
372
  #pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
373
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
374
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
375
376
  }
  if (num_data_ != sum) {
377
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
378
379
  }
  num_queries_ = len;
380
  query_boundaries_.resize(num_queries_ + 1);
Guolin Ke's avatar
Guolin Ke committed
381
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
382
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
383
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
384
385
  }
  LoadQueryWeights();
386
  query_load_from_file_ = false;
387
388
389
390
391
392
393
394
395
396
  #ifdef USE_CUDA_EXP
  if (cuda_metadata_ != nullptr) {
    if (query_weights_.size() > 0) {
      CHECK_EQ(query_weights_.size(), static_cast<size_t>(num_queries_));
      cuda_metadata_->SetQuery(query_boundaries_.data(), query_weights_.data(), num_queries_);
    } else {
      cuda_metadata_->SetQuery(query_boundaries_.data(), nullptr, num_queries_);
    }
  }
  #endif  // USE_CUDA_EXP
397
}
Guolin Ke's avatar
Guolin Ke committed
398

Guolin Ke's avatar
Guolin Ke committed
399
400
401
402
403
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
404
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
405
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
406
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
407
408
    return;
  }
409
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
410
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
411
  weights_ = std::vector<label_t>(num_weights_);
412
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
413
  for (data_size_t i = 0; i < num_weights_; ++i) {
414
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
415
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
416
    weights_[i] = Common::AvoidInf(static_cast<label_t>(tmp_weight));
Guolin Ke's avatar
Guolin Ke committed
417
  }
418
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
419
420
}

421
void Metadata::LoadInitialScore(const std::string& data_filename) {
Guolin Ke's avatar
Guolin Ke committed
422
  num_init_score_ = 0;
423
424
  std::string init_score_filename(data_filename);
  init_score_filename = std::string(data_filename);
425
426
  // default init_score file name
  init_score_filename.append(".init");
Guolin Ke's avatar
Guolin Ke committed
427
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
428
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
429
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
430
431
    return;
  }
432
433
  Log::Info("Loading initial scores...");

434
435
436
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
437
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
438

Guolin Ke's avatar
Guolin Ke committed
439
  init_score_ = std::vector<double>(num_init_score_);
440
  if (num_class == 1) {
441
    #pragma omp parallel for schedule(static)
442
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
443
      double tmp = 0.0f;
444
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
445
      init_score_[i] = Common::AvoidInf(static_cast<double>(tmp));
446
    }
447
  } else {
448
    std::vector<std::string> oneline_init_score;
449
    #pragma omp parallel for schedule(static)
450
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
451
      double tmp = 0.0f;
452
453
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
454
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
455
456
457
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
458
        init_score_[static_cast<size_t>(k) * num_line + i] = Common::AvoidInf(static_cast<double>(tmp));
459
      }
460
    }
Guolin Ke's avatar
Guolin Ke committed
461
  }
462
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
463
464
465
466
467
468
469
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
470
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
471
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
472
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
473
474
    return;
  }
475
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
476
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
477
478
479
480
481
482
483
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
484
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
485
486
487
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
488
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
489
490
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
491
  query_weights_.clear();
492
  Log::Info("Loading query weights...");
493
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
494
495
496
497
498
499
500
501
502
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

#ifdef USE_CUDA_EXP
/*!
 * \brief Create the CUDA-side copy of this metadata.
 *
 * Replaces cuda_metadata_ with a new CUDAMetadata built for `gpu_device_id`,
 * then initializes it from the host-side labels, weights, query boundaries,
 * query weights and initial scores.
 * \param gpu_device_id Device id passed to the CUDAMetadata constructor
 */
void Metadata::CreateCUDAMetadata(const int gpu_device_id) {
  cuda_metadata_.reset(new CUDAMetadata(gpu_device_id));
  cuda_metadata_->Init(label_, weights_, query_boundaries_, query_weights_, init_score_);
}
#endif  // USE_CUDA_EXP

Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
514
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_data_));
Guolin Ke's avatar
Guolin Ke committed
515
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
516
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_weights_));
Guolin Ke's avatar
Guolin Ke committed
517
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
518
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_queries_));
Guolin Ke's avatar
Guolin Ke committed
519

Guolin Ke's avatar
Guolin Ke committed
520
  if (!label_.empty()) { label_.clear(); }
521
  label_ = std::vector<label_t>(num_data_);
522
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t) * num_data_);
523
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
524
525

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
526
    if (!weights_.empty()) { weights_.clear(); }
527
    weights_ = std::vector<label_t>(num_weights_);
528
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t) * num_weights_);
529
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_weights_);
530
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
531
532
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
533
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
534
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
535
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t) * (num_queries_ + 1));
536
537
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(data_size_t) *
                                              (num_queries_ + 1));
538
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
539
  }
Guolin Ke's avatar
Guolin Ke committed
540
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
541
542
}

543
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
544
545
546
547
  writer->AlignedWrite(&num_data_, sizeof(num_data_));
  writer->AlignedWrite(&num_weights_, sizeof(num_weights_));
  writer->AlignedWrite(&num_queries_, sizeof(num_queries_));
  writer->AlignedWrite(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
548
  if (!weights_.empty()) {
549
    writer->AlignedWrite(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
550
  }
Guolin Ke's avatar
Guolin Ke committed
551
  if (!query_boundaries_.empty()) {
552
553
    writer->AlignedWrite(query_boundaries_.data(),
                         sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
554
  }
555
556
557
558
  if (num_init_score_ > 0) {
    Log::Warning("Please note that `init_score` is not saved in binary file.\n"
      "If you need it, please set it again after loading Dataset.");
  }
Guolin Ke's avatar
Guolin Ke committed
559
560
}

561
size_t Metadata::SizesInByte() const {
562
563
564
565
  size_t size = VirtualFileWriter::AlignedSize(sizeof(num_data_)) +
                VirtualFileWriter::AlignedSize(sizeof(num_weights_)) +
                VirtualFileWriter::AlignedSize(sizeof(num_queries_));
  size += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
566
  if (!weights_.empty()) {
567
    size += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
568
  }
Guolin Ke's avatar
Guolin Ke committed
569
  if (!query_boundaries_.empty()) {
570
571
    size += VirtualFileWriter::AlignedSize(sizeof(data_size_t) *
                                           (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
572
573
574
575
576
577
  }
  return size;
}


}  // namespace LightGBM