#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>

#include <string>
#include <vector>

namespace LightGBM {

// Creates an empty Metadata: all element counters are zero and none of the
// optional fields is flagged as having been loaded from a side file.
Metadata::Metadata() {
  // nothing has been loaded from a file yet
  init_score_load_from_file_ = false;
  query_load_from_file_ = false;
  weight_load_from_file_ = false;

  // element counters
  num_data_ = 0;
  num_queries_ = 0;
  num_weights_ = 0;
  num_init_score_ = 0;
}

void Metadata::Init(const char * data_filename, const char* initscore_file) {
Guolin Ke's avatar
Guolin Ke committed
20
21
22
23
24
  data_filename_ = data_filename;
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
25
  LoadInitialScore(initscore_file);
Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
30
}

// Nothing to release by hand: the members clean up after themselves.
Metadata::~Metadata() = default;

// Allocates zero-filled per-row storage for `num_data` rows.
//   num_data   - number of rows; labels are always allocated for them
//   weight_idx - when >= 0, per-row weights are (re)allocated and any weights
//                already held are discarded
//   query_idx  - when >= 0, per-row query ids are (re)allocated and any query
//                boundaries / query weights already held are discarded
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  label_.assign(num_data_, static_cast<label_t>(0));

  const bool use_weight_column = weight_idx >= 0;
  if (use_weight_column) {
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
    }
    // assign() replaces whatever was stored before with num_data_ zeros
    weights_.assign(num_data_, static_cast<label_t>(0.0f));
    num_weights_ = num_data_;
    weight_load_from_file_ = false;
  }

  const bool use_query_column = query_idx >= 0;
  if (use_query_column) {
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    query_weights_.clear();
    queries_.assign(num_data_, static_cast<data_size_t>(0));
    query_load_from_file_ = false;
  }
}

void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

65
  label_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
66
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
67
68
69
70
  for (data_size_t i = 0; i < num_used_indices; i++) {
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
71
  if (!fullset.weights_.empty()) {
72
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
73
    num_weights_ = num_used_indices;
Guolin Ke's avatar
Guolin Ke committed
74
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
    for (data_size_t i = 0; i < num_used_indices; i++) {
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
82
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
83
84
85
86
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
    init_score_ = std::vector<double>(num_used_indices*num_class);
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
#pragma omp parallel for schedule(static)
87
88
89
90
    for (int k = 0; k < num_class; ++k) {
      for (data_size_t i = 0; i < num_used_indices; i++) {
        init_score_[k*num_data_ + i] = fullset.init_score_[k* fullset.num_data_ + used_indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
91
92
93
94
95
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
96
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
130
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
131
132
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
133
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
134
  num_data_ = static_cast<data_size_t>(used_indices.size());
135
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
136
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
137
138
139
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
140
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
141
142
143
}

// Either validates the metadata against the local data size, or cuts
// file-loaded metadata down to the locally used rows.
//   num_all_data      - row count that file-loaded weights, queries and
//                       initial scores are checked against when partitioning
//   used_data_indices - rows kept locally; when empty, only the consistency
//                       checks against num_data_ are performed
// Any size mismatch clears the offending field and is reported through
// Log::Fatal.
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    // ---- no partitioning: convert query ids (if any) and validate sizes ----
    if (!queries_.empty()) {
      // Per-row query ids become cumulative boundaries: every run of equal
      // consecutive ids is one query.
      std::vector<data_size_t> boundaries(1, 0);
      data_size_t prev_qid = -1;
      data_size_t run_length = 0;
      for (data_size_t row = 0; row < num_data_; ++row) {
        const data_size_t qid = queries_[row];
        if (qid != prev_qid) {
          if (run_length > 0) {
            boundaries.push_back(boundaries.back() + run_length);
          }
          run_length = 0;
          prev_qid = qid;
        }
        ++run_length;
      }
      // close the run that was still open when the rows ran out
      boundaries.push_back(boundaries.back() + run_length);
      num_queries_ = static_cast<data_size_t>(boundaries.size() - 1);
      query_boundaries_.swap(boundaries);
      LoadQueryWeights();
      queries_.clear();
    }

    // weights: one per row
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // query boundaries: the last boundary must equal the row count
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // initial scores: a whole number of scores per row
    if (!init_score_.empty() && (num_init_score_ % num_data_) != 0) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
    return;
  }

  // ---- partitioning: keep only the rows listed in used_data_indices ----
  if (!queries_.empty()) {
    Log::Fatal("Cannot used query_id for parallel training");
  }
  const data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());

  if (weight_load_from_file_ && !weights_.empty()) {
    if (num_weights_ != num_all_data) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    } else {
      // gather the local weights out of the full-size array
      const auto global_weights = weights_;
      num_weights_ = num_data_;
      weights_ = std::vector<label_t>(num_data_);
      const int local_cnt = static_cast<int>(used_data_indices.size());
#pragma omp parallel for schedule(static)
      for (int row = 0; row < local_cnt; ++row) {
        weights_[row] = global_weights[used_data_indices[row]];
      }
    }
  }

  if (query_load_from_file_ && !query_boundaries_.empty()) {
    if (query_boundaries_[num_queries_] != num_all_data) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    } else {
      // Collect the queries whose rows are all present locally. A query that
      // is only partly present is a partition error.
      std::vector<data_size_t> kept_queries;
      data_size_t data_idx = 0;
      for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
        const data_size_t start = query_boundaries_[qid];
        const data_size_t end = query_boundaries_[qid + 1];
        const data_size_t len = end - start;
        const data_size_t next_local_row = used_data_indices[data_idx];
        if (next_local_row > start) {
          // this query starts before the next local row: not held locally
          continue;
        }
        const bool whole_query_local = next_local_row == start
          && num_used_data >= data_idx + len
          && used_data_indices[data_idx + len - 1] == end - 1;
        if (whole_query_local) {
          kept_queries.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      }
      // rebuild cumulative boundaries from the lengths of the kept queries
      const auto global_boundaries = query_boundaries_;
      num_queries_ = static_cast<data_size_t>(kept_queries.size());
      query_boundaries_ = std::vector<data_size_t>(kept_queries.size() + 1);
      query_boundaries_[0] = 0;
      for (data_size_t q = 0; q < num_queries_; ++q) {
        const data_size_t qid = kept_queries[q];
        const data_size_t len = global_boundaries[qid + 1] - global_boundaries[qid];
        query_boundaries_[q + 1] = query_boundaries_[q] + len;
      }
    }
  }

  if (init_score_load_from_file_ && !init_score_.empty()) {
    if ((num_init_score_ % num_all_data) != 0) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    } else {
      // gather the local scores; scores are stored class by class
      const auto global_scores = init_score_;
      const int num_class = static_cast<int>(num_init_score_ / num_all_data);
      num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
      init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static)
      for (int k = 0; k < num_class; ++k) {
        for (size_t row = 0; row < used_data_indices.size(); ++row) {
          init_score_[k * num_data_ + row] = global_scores[k * num_all_data + used_data_indices[row]];
        }
      }
    }
  }

  // re-load query weight
  LoadQueryWeights();
}

void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
  if (!init_score_.empty()) { init_score_.clear(); }
  num_init_score_ = len;
Guolin Ke's avatar
Guolin Ke committed
295
296
297
298
  init_score_ = std::vector<double>(len);
#pragma omp parallel for schedule(static)
  for (int64_t i = 0; i < num_init_score_; ++i) {
    init_score_[i] = init_score[i];
299
  }
300
  init_score_load_from_file_ = false;
301
302
}

void Metadata::SetLabel(const label_t* label, data_size_t len) {
304
  std::lock_guard<std::mutex> lock(mutex_);
305
306
307
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
308
  if (num_data_ != len) {
309
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
310
  }
Guolin Ke's avatar
Guolin Ke committed
311
  if (!label_.empty()) { label_.clear(); }
312
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
313
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
318
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = label[i];
  }
}

void Metadata::SetWeights(const label_t* weights, data_size_t len) {
320
  std::lock_guard<std::mutex> lock(mutex_);
321
322
323
324
325
326
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
327
  if (num_data_ != len) {
328
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
329
  }
Guolin Ke's avatar
Guolin Ke committed
330
  if (!weights_.empty()) { weights_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
331
  num_weights_ = num_data_;
332
  weights_ = std::vector<label_t>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
333
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
  for (data_size_t i = 0; i < num_weights_; ++i) {
    weights_[i] = weights[i];
  }
  LoadQueryWeights();
338
  weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
339
340
}

void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
342
  std::lock_guard<std::mutex> lock(mutex_);
343
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
344
  if (query == nullptr || len == 0) {
345
346
347
348
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
349
  data_size_t sum = 0;
Guolin Ke's avatar
Guolin Ke committed
350
#pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
351
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
352
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
353
354
  }
  if (num_data_ != sum) {
355
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
356
  }
Guolin Ke's avatar
Guolin Ke committed
357
  if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
358
  num_queries_ = len;
Guolin Ke's avatar
Guolin Ke committed
359
360
  query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
361
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
362
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
363
364
  }
  LoadQueryWeights();
365
  query_load_from_file_ = false;
366
}
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
373
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
374
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
375
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
376
377
    return;
  }
378
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
379
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
380
  weights_ = std::vector<label_t>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
381
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
382
  for (data_size_t i = 0; i < num_weights_; ++i) {
383
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
384
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
385
    weights_[i] = static_cast<label_t>(tmp_weight);
Guolin Ke's avatar
Guolin Ke committed
386
  }
387
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
388
389
}

void Metadata::LoadInitialScore(const char* initscore_file) {
Guolin Ke's avatar
Guolin Ke committed
391
  num_init_score_ = 0;
392
393
394
395
396
397
  std::string init_score_filename(initscore_file);
  if (init_score_filename.size() <= 0) {
    init_score_filename = std::string(data_filename_);
    // default weight file name
    init_score_filename.append(".init");
  }
Guolin Ke's avatar
Guolin Ke committed
398
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
399
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
400
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
401
402
    return;
  }
403
404
  Log::Info("Loading initial scores...");

405
406
407
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
408
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
409

Guolin Ke's avatar
Guolin Ke committed
410
  init_score_ = std::vector<double>(num_init_score_);
411
  if (num_class == 1) {
Guolin Ke's avatar
Guolin Ke committed
412
#pragma omp parallel for schedule(static)
413
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
414
      double tmp = 0.0f;
415
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
Guolin Ke's avatar
Guolin Ke committed
416
      init_score_[i] = static_cast<double>(tmp);
417
    }
418
  } else {
419
    std::vector<std::string> oneline_init_score;
Guolin Ke's avatar
Guolin Ke committed
420
#pragma omp parallel for schedule(static)
421
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
422
      double tmp = 0.0f;
423
424
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
425
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
426
427
428
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
Guolin Ke's avatar
Guolin Ke committed
429
        init_score_[k * num_line + i] = static_cast<double>(tmp);
430
      }
431
    }
Guolin Ke's avatar
Guolin Ke committed
432
  }
433
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
434
435
436
437
438
439
440
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
441
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
442
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
443
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
444
445
    return;
  }
446
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
447
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
448
449
450
451
452
453
454
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
455
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
456
457
458
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
459
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
460
461
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
462
  query_weights_.clear();
463
  Log::Info("Loading query weights...");
464
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_data_);
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_weights_);
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_queries_);

Guolin Ke's avatar
Guolin Ke committed
484
  if (!label_.empty()) { label_.clear(); }
485
486
487
  label_ = std::vector<label_t>(num_data_);
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t)*num_data_);
  mem_ptr += sizeof(label_t)*num_data_;
Guolin Ke's avatar
Guolin Ke committed
488
489

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
490
    if (!weights_.empty()) { weights_.clear(); }
491
492
493
    weights_ = std::vector<label_t>(num_weights_);
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t)*num_weights_);
    mem_ptr += sizeof(label_t)*num_weights_;
494
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
495
496
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
497
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
498
499
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
500
    mem_ptr += sizeof(data_size_t)*(num_queries_ + 1);
501
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
502
  }
Guolin Ke's avatar
Guolin Ke committed
503
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
504
505
}

void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
  writer->Write(&num_data_, sizeof(num_data_));
  writer->Write(&num_weights_, sizeof(num_weights_));
  writer->Write(&num_queries_, sizeof(num_queries_));
  writer->Write(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
511
  if (!weights_.empty()) {
512
    writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
513
  }
Guolin Ke's avatar
Guolin Ke committed
514
  if (!query_boundaries_.empty()) {
515
    writer->Write(query_boundaries_.data(), sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
521
  }
}

size_t Metadata::SizesInByte() const  {
  size_t size = sizeof(num_data_) + sizeof(num_weights_)
    + sizeof(num_queries_);
522
  size += sizeof(label_t) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
523
  if (!weights_.empty()) {
524
    size += sizeof(label_t) * num_weights_;
Guolin Ke's avatar
Guolin Ke committed
525
  }
Guolin Ke's avatar
Guolin Ke committed
526
  if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
527
528
529
530
531
532
533
    size += sizeof(data_size_t) * (num_queries_ + 1);
  }
  return size;
}


}  // namespace LightGBM