/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6

#include <string>
7
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
8

9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>

Guolin Ke's avatar
Guolin Ke committed
12
13
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
14
/*! \brief Creates an empty metadata object: all counters are zero and nothing is marked as loaded from file. */
Metadata::Metadata() {
  // no side information has been read from a companion file yet
  init_score_load_from_file_ = false;
  query_load_from_file_ = false;
  weight_load_from_file_ = false;
  // every element counter starts at zero
  num_queries_ = 0;
  num_data_ = 0;
  num_init_score_ = 0;
  num_weights_ = 0;
}

24
void Metadata::Init(const char* data_filename) {
Guolin Ke's avatar
Guolin Ke committed
25
26
27
28
29
  data_filename_ = data_filename;
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
30
  LoadInitialScore();
Guolin Ke's avatar
Guolin Ke committed
31
32
33
34
35
}

/*! \brief Destructor; all members release their own storage, so nothing is done explicitly. */
Metadata::~Metadata() = default;

36
/*!
 * \brief Allocates label storage and, when the data file itself carries them, weight / query-id storage.
 * \param num_data Number of rows
 * \param weight_idx Non-negative when weights are taken from the data file
 * \param query_idx Non-negative when query ids are taken from the data file
 */
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  label_.assign(static_cast<size_t>(num_data_), label_t());

  const bool weights_in_data = (weight_idx >= 0);
  if (weights_in_data) {
    // weights found in the data file take precedence over a previously loaded weight file
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
    }
    weights_.assign(static_cast<size_t>(num_data_), 0.0f);
    num_weights_ = num_data_;
    weight_load_from_file_ = false;
  }

  const bool queries_in_data = (query_idx >= 0);
  if (queries_in_data) {
    // query ids found in the data file take precedence over a previously loaded query file
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    query_weights_.clear();
    // one query id per row, filled in later by the loader
    queries_.assign(static_cast<size_t>(num_data_), data_size_t(0));
    query_load_from_file_ = false;
  }
}

Guolin Ke's avatar
Guolin Ke committed
59
60
61
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

62
  label_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
63
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
64
  for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
65
66
67
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
68
  if (!fullset.weights_.empty()) {
69
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
70
    num_weights_ = num_used_indices;
Guolin Ke's avatar
Guolin Ke committed
71
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
72
    for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
73
74
75
76
77
78
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
79
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
80
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
81
    init_score_ = std::vector<double>(static_cast<size_t>(num_used_indices) * num_class);
Guolin Ke's avatar
Guolin Ke committed
82
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
83
    #pragma omp parallel for schedule(static)
84
    for (int k = 0; k < num_class; ++k) {
85
86
87
88
      const size_t offset_dest = static_cast<size_t>(k) * num_data_;
      const size_t offset_src = static_cast<size_t>(k) * fullset.num_data_;
      for (data_size_t i = 0; i < num_used_indices; ++i) {
        init_score_[offset_dest + i] = fullset.init_score_[offset_src + used_indices[i]];
89
      }
Guolin Ke's avatar
Guolin Ke committed
90
91
92
93
94
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
95
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

Guolin Ke's avatar
Guolin Ke committed
128
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
129
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
130
131
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
132
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
133
  num_data_ = static_cast<data_size_t>(used_indices.size());
134
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
135
#pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
136
137
138
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
139
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
140
141
142
}

/*!
 * \brief Validates the side information against the data size, or reduces it to the locally used rows.
 * \param num_all_data Number of rows of the complete data set
 * \param used_data_indices Rows kept locally; when empty, only validation (and query-id conversion) is done
 */
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // convert per-row query ids into cumulative query boundaries
      std::vector<data_size_t> boundaries;
      boundaries.push_back(0);
      data_size_t row = 0;
      for (; row < num_data_; ++row) {
        // a new query starts wherever the id differs from the one of the previous row
        if (row > 0 && queries_[row] != queries_[row - 1]) {
          boundaries.push_back(row);
        }
      }
      // close the last query
      boundaries.push_back(row);
      num_queries_ = static_cast<data_size_t>(boundaries.size() - 1);
      query_boundaries_.swap(boundaries);
      LoadQueryWeights();
      queries_.clear();
    }

    // check weights
    const bool weights_mismatch = !weights_.empty() && num_weights_ != num_data_;
    if (weights_mismatch) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // check query boundaries
    const bool query_mismatch = !query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_;
    if (query_mismatch) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // check initial scores: their count must be a multiple of the row count
    const bool score_mismatch = !init_score_.empty() && (num_init_score_ % num_data_) != 0;
    if (score_mismatch) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
    return;
  }

  // from here on: only a subset of the rows is used locally
  if (!queries_.empty()) {
    Log::Fatal("Cannot used query_id for parallel training");
  }
  const data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());

  if (weight_load_from_file_) {
    // check weights against the complete data size
    if (weights_.size() > 0 && num_weights_ != num_all_data) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }
    // get local weights
    if (!weights_.empty()) {
      const std::vector<label_t> all_weights(weights_);
      num_weights_ = num_data_;
      weights_.assign(static_cast<size_t>(num_data_), label_t());
#pragma omp parallel for schedule(static, 512)
      for (int i = 0; i < static_cast<int>(num_used_data); ++i) {
        weights_[i] = all_weights[used_data_indices[i]];
      }
    }
  }

  if (query_load_from_file_) {
    // check query boundaries against the complete data size
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }
    // get local query boundaries
    if (!query_boundaries_.empty()) {
      std::vector<data_size_t> kept_queries;
      data_size_t cursor = 0;
      for (data_size_t qid = 0; qid < num_queries_ && cursor < num_used_data; ++qid) {
        const data_size_t q_begin = query_boundaries_[qid];
        const data_size_t q_end = query_boundaries_[qid + 1];
        const data_size_t q_len = q_end - q_begin;
        const data_size_t next_used = used_data_indices[cursor];
        if (next_used > q_begin) {
          // this query lies before the next used row: it is not local
          continue;
        }
        // a local query has to be present completely, first row to last row
        const bool whole_query_present = next_used == q_begin
          && num_used_data >= cursor + q_len
          && used_data_indices[cursor + q_len - 1] == q_end - 1;
        if (whole_query_present) {
          kept_queries.push_back(qid);
          cursor += q_len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      }
      // rebuild cumulative boundaries from the lengths of the kept queries
      std::vector<data_size_t> local_boundaries(kept_queries.size() + 1);
      local_boundaries[0] = 0;
      for (size_t i = 0; i < kept_queries.size(); ++i) {
        const data_size_t qid = kept_queries[i];
        const data_size_t q_len = query_boundaries_[qid + 1] - query_boundaries_[qid];
        local_boundaries[i + 1] = local_boundaries[i] + q_len;
      }
      num_queries_ = static_cast<data_size_t>(kept_queries.size());
      query_boundaries_.swap(local_boundaries);
    }
  }

  if (init_score_load_from_file_) {
    // the number of initial scores must be a multiple of the complete data size
    if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
    // get local initial scores
    if (!init_score_.empty()) {
      const std::vector<double> all_scores(init_score_);
      const int num_class = static_cast<int>(num_init_score_ / num_all_data);
      num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
      init_score_.assign(static_cast<size_t>(num_init_score_), 0.0);
#pragma omp parallel for schedule(static)
      for (int k = 0; k < num_class; ++k) {
        // scores are stored class by class
        const size_t dest_base = static_cast<size_t>(k) * num_data_;
        const size_t src_base = static_cast<size_t>(k) * num_all_data;
        for (size_t i = 0; i < used_data_indices.size(); ++i) {
          init_score_[dest_base + i] = all_scores[src_base + used_data_indices[i]];
        }
      }
    }
  }

  // re-load query weight
  LoadQueryWeights();
}

283
284
285
286
287
288
289
290
291
292
293
void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
294
  if (init_score_.empty()) { init_score_.resize(len); }
295
  num_init_score_ = len;
296

Guolin Ke's avatar
Guolin Ke committed
297
  #pragma omp parallel for schedule(static, 512) if (num_init_score_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
298
  for (int64_t i = 0; i < num_init_score_; ++i) {
299
    init_score_[i] = Common::AvoidInf(init_score[i]);
300
  }
301
  init_score_load_from_file_ = false;
302
303
}

304
void Metadata::SetLabel(const label_t* label, data_size_t len) {
305
  std::lock_guard<std::mutex> lock(mutex_);
306
307
308
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
309
  if (num_data_ != len) {
310
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
311
  }
312
  if (label_.empty()) { label_.resize(num_data_); }
313

Guolin Ke's avatar
Guolin Ke committed
314
  #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
315
  for (data_size_t i = 0; i < num_data_; ++i) {
316
    label_[i] = Common::AvoidInf(label[i]);
Guolin Ke's avatar
Guolin Ke committed
317
318
319
  }
}

320
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
321
  std::lock_guard<std::mutex> lock(mutex_);
322
323
324
325
326
327
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
328
  if (num_data_ != len) {
329
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
330
  }
331
  if (weights_.empty()) { weights_.resize(num_data_); }
Guolin Ke's avatar
Guolin Ke committed
332
  num_weights_ = num_data_;
333

Guolin Ke's avatar
Guolin Ke committed
334
  #pragma omp parallel for schedule(static, 512) if (num_weights_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
335
  for (data_size_t i = 0; i < num_weights_; ++i) {
336
    weights_[i] = Common::AvoidInf(weights[i]);
Guolin Ke's avatar
Guolin Ke committed
337
338
  }
  LoadQueryWeights();
339
  weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
340
341
}

Guolin Ke's avatar
Guolin Ke committed
342
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
343
  std::lock_guard<std::mutex> lock(mutex_);
344
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
345
  if (query == nullptr || len == 0) {
346
347
348
349
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
350
  data_size_t sum = 0;
351
  #pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
352
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
353
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
354
355
  }
  if (num_data_ != sum) {
356
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
357
358
  }
  num_queries_ = len;
359
  query_boundaries_.resize(num_queries_ + 1);
Guolin Ke's avatar
Guolin Ke committed
360
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
361
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
362
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
363
364
  }
  LoadQueryWeights();
365
  query_load_from_file_ = false;
366
}
Guolin Ke's avatar
Guolin Ke committed
367

Guolin Ke's avatar
Guolin Ke committed
368
369
370
371
372
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
373
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
374
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
375
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
376
377
    return;
  }
378
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
379
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
380
  weights_ = std::vector<label_t>(num_weights_);
381
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
382
  for (data_size_t i = 0; i < num_weights_; ++i) {
383
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
384
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
385
    weights_[i] = Common::AvoidInf(static_cast<label_t>(tmp_weight));
Guolin Ke's avatar
Guolin Ke committed
386
  }
387
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
388
389
}

390
void Metadata::LoadInitialScore() {
Guolin Ke's avatar
Guolin Ke committed
391
  num_init_score_ = 0;
392
393
394
395
  std::string init_score_filename(data_filename_);
  init_score_filename = std::string(data_filename_);
  // default init_score file name
  init_score_filename.append(".init");
Guolin Ke's avatar
Guolin Ke committed
396
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
397
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
398
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
399
400
    return;
  }
401
402
  Log::Info("Loading initial scores...");

403
404
405
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
406
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
407

Guolin Ke's avatar
Guolin Ke committed
408
  init_score_ = std::vector<double>(num_init_score_);
409
  if (num_class == 1) {
410
    #pragma omp parallel for schedule(static)
411
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
412
      double tmp = 0.0f;
413
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
414
      init_score_[i] = Common::AvoidInf(static_cast<double>(tmp));
415
    }
416
  } else {
417
    std::vector<std::string> oneline_init_score;
418
    #pragma omp parallel for schedule(static)
419
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
420
      double tmp = 0.0f;
421
422
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
423
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
424
425
426
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
427
        init_score_[static_cast<size_t>(k) * num_line + i] = Common::AvoidInf(static_cast<double>(tmp));
428
      }
429
    }
Guolin Ke's avatar
Guolin Ke committed
430
  }
431
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
432
433
434
435
436
437
438
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
439
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
440
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
441
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
442
443
    return;
  }
444
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
445
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
446
447
448
449
450
451
452
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
453
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
454
455
456
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
457
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
458
459
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
460
  query_weights_.clear();
461
  Log::Info("Loading query weights...");
462
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_data_);
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_weights_);
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_queries_);

Guolin Ke's avatar
Guolin Ke committed
482
  if (!label_.empty()) { label_.clear(); }
483
  label_ = std::vector<label_t>(num_data_);
484
485
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t) * num_data_);
  mem_ptr += sizeof(label_t) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
486
487

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
488
    if (!weights_.empty()) { weights_.clear(); }
489
    weights_ = std::vector<label_t>(num_weights_);
490
491
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t) * num_weights_);
    mem_ptr += sizeof(label_t) * num_weights_;
492
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
493
494
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
495
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
496
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
497
498
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t) * (num_queries_ + 1));
    mem_ptr += sizeof(data_size_t) * (num_queries_ + 1);
499
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
500
  }
Guolin Ke's avatar
Guolin Ke committed
501
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
502
503
}

504
505
506
507
508
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
  writer->Write(&num_data_, sizeof(num_data_));
  writer->Write(&num_weights_, sizeof(num_weights_));
  writer->Write(&num_queries_, sizeof(num_queries_));
  writer->Write(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
509
  if (!weights_.empty()) {
510
    writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
511
  }
Guolin Ke's avatar
Guolin Ke committed
512
  if (!query_boundaries_.empty()) {
513
    writer->Write(query_boundaries_.data(), sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
514
  }
515
516
517
518
  if (num_init_score_ > 0) {
    Log::Warning("Please note that `init_score` is not saved in binary file.\n"
      "If you need it, please set it again after loading Dataset.");
  }
Guolin Ke's avatar
Guolin Ke committed
519
520
}

521
size_t Metadata::SizesInByte() const {
Guolin Ke's avatar
Guolin Ke committed
522
523
  size_t size = sizeof(num_data_) + sizeof(num_weights_)
    + sizeof(num_queries_);
524
  size += sizeof(label_t) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
525
  if (!weights_.empty()) {
526
    size += sizeof(label_t) * num_weights_;
Guolin Ke's avatar
Guolin Ke committed
527
  }
Guolin Ke's avatar
Guolin Ke committed
528
  if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
529
530
531
532
533
534
535
    size += sizeof(data_size_t) * (num_queries_ + 1);
  }
  return size;
}


}  // namespace LightGBM