metadata.cpp 18.1 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>

#include <string>
9
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
10
11
12

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
13
/*!
 * Default constructor: all counts start at zero and no side file
 * (weights / queries / initial scores) is marked as loaded.
 */
Metadata::Metadata() {
  // Element counts.
  num_data_ = 0;
  num_weights_ = 0;
  num_queries_ = 0;
  num_init_score_ = 0;
  // Provenance flags consulted by CheckOrPartition.
  init_score_load_from_file_ = false;
  query_load_from_file_ = false;
  weight_load_from_file_ = false;
}

23
void Metadata::Init(const char * data_filename, const char* initscore_file) {
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
  data_filename_ = data_filename;
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
29
  LoadInitialScore(initscore_file);
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
}

// Nothing to release by hand: the buffers this file manages are std::vector members.
Metadata::~Metadata() = default;

35
/*!
 * Allocate metadata buffers for a dataset whose labels (and optionally weights
 * and query ids) are read from columns of the data file itself.
 *
 * When the data file carries a weight or query column, that column takes
 * precedence over any ".weight" / ".query" side file loaded earlier. Those
 * side-file values are discarded here.
 * \param num_data   Number of rows.
 * \param weight_idx >= 0 when the data file has a weight column, negative otherwise.
 * \param query_idx  >= 0 when the data file has a query-id column, negative otherwise.
 */
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  // std::vector value-initializes its elements, so every buffer below starts zeroed.
  label_ = std::vector<label_t>(num_data_);
  if (weight_idx >= 0) {
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
    }
    weights_ = std::vector<label_t>(num_data_);
    num_weights_ = num_data_;
    weight_load_from_file_ = false;
  }
  if (query_idx >= 0) {
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
      // Keep the count consistent with the (now empty) boundaries.
      num_queries_ = 0;
    }
    // Per-query weights were derived from the discarded boundaries.
    query_weights_.clear();
    // Raw per-row query ids; CheckOrPartition converts them to boundaries.
    queries_ = std::vector<data_size_t>(num_data_);
    query_load_from_file_ = false;
  }
}

Guolin Ke's avatar
Guolin Ke committed
66
67
68
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

69
  label_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
70
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
  for (data_size_t i = 0; i < num_used_indices; i++) {
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
75
  if (!fullset.weights_.empty()) {
76
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
77
    num_weights_ = num_used_indices;
Guolin Ke's avatar
Guolin Ke committed
78
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
79
80
81
82
83
84
85
    for (data_size_t i = 0; i < num_used_indices; i++) {
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
86
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
    init_score_ = std::vector<double>(num_used_indices*num_class);
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
#pragma omp parallel for schedule(static)
91
92
93
94
    for (int k = 0; k < num_class; ++k) {
      for (data_size_t i = 0; i < num_used_indices; i++) {
        init_score_[k*num_data_ + i] = fullset.init_score_[k* fullset.num_data_ + used_indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
100
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

Guolin Ke's avatar
Guolin Ke committed
133
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
134
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
135
136
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
137
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
138
  num_data_ = static_cast<data_size_t>(used_indices.size());
139
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
140
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
141
142
143
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
144
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
145
146
147
}

/*!
 * Validate the auxiliary metadata and, when rows are distributed, cut it down
 * to the local rows.
 *
 * When used_data_indices is empty, all rows are present:
 * - Raw per-row query ids (queries_) are converted into query boundaries.
 * - The sizes of weights, query boundaries and initial scores are checked
 *   against num_data_.
 *
 * Otherwise only the rows in used_data_indices are kept:
 * - Buffers that were loaded from side files still describe all num_all_data
 *   rows. They are checked against num_all_data and then reduced to the local
 *   rows.
 * - Query weights are recomputed at the end.
 *
 * NOTE(review): the partitioned buffers are sized by num_data_ but filled from
 * used_data_indices. This presumes num_data_ == used_data_indices.size() here
 * (e.g. after PartitionLabel). Confirm against callers.
 * \param num_all_data      Number of rows before partitioning.
 * \param used_data_indices Global indices of the locally kept rows (empty = no partition).
 */
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // Convert per-row query ids to boundaries: each run of equal consecutive
      // ids forms one query.
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      tmp_buffer.push_back(cur_cnt);
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
      queries_.clear();
    }
    // check weights
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // check query boundaries
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // Initial scores must be a whole number of per-class blocks of num_data_.
    // The num_data_ <= 0 guard avoids a modulo by zero.
    if (!init_score_.empty() && (num_data_ <= 0 || (num_init_score_ % num_data_) != 0)) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for parallel training");
    }
    const data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());

    // Weights: only side-file weights cover all rows and need partitioning.
    if (weight_load_from_file_) {
      if (!weights_.empty() && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // get local weights
      if (!weights_.empty()) {
        std::vector<label_t> old_weights;
        old_weights.swap(weights_);
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static)
        for (data_size_t i = 0; i < num_used_data; ++i) {
          weights_[i] = old_weights[used_data_indices[i]];
        }
      }
    }

    // Query boundaries: keep only the queries that are used completely.
    if (query_load_from_file_) {
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // get local query boundaries
      if (!query_boundaries_.empty()) {
        std::vector<data_size_t> used_query;
        data_size_t data_idx = 0;
        for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
          const data_size_t start = query_boundaries_[qid];
          const data_size_t end = query_boundaries_[qid + 1];
          const data_size_t len = end - start;
          if (used_data_indices[data_idx] > start) {
            continue;
          } else if (used_data_indices[data_idx] == start) {
            if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
              used_query.push_back(qid);
              data_idx += len;
            } else {
              Log::Fatal("Data partition error, data didn't match queries");
            }
          } else {
            Log::Fatal("Data partition error, data didn't match queries");
          }
        }
        std::vector<data_size_t> old_query_boundaries;
        old_query_boundaries.swap(query_boundaries_);
        query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
        num_queries_ = static_cast<data_size_t>(used_query.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          const data_size_t qid = used_query[i];
          const data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
          query_boundaries_[i + 1] = query_boundaries_[i] + len;
        }
      }
    }

    // Initial scores: class-major layout, class k of row i at k * num_rows + i.
    if (init_score_load_from_file_) {
      if (!init_score_.empty() && (num_all_data <= 0 || (num_init_score_ % num_all_data) != 0)) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }
      // get local initial scores
      if (!init_score_.empty()) {
        std::vector<double> old_scores;
        old_scores.swap(init_score_);
        const int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
        // size_t offsets: k * num_all_data can exceed the int range.
        const size_t local_stride = static_cast<size_t>(num_data_);
        const size_t global_stride = static_cast<size_t>(num_all_data);
#pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          const size_t local_offset = k * local_stride;
          const size_t global_offset = k * global_stride;
          for (data_size_t i = 0; i < num_used_data; ++i) {
            init_score_[local_offset + i] = old_scores[global_offset + used_data_indices[i]];
          }
        }
      }
    }
    // re-load query weight
    LoadQueryWeights();
  }
}

286
287
288
289
290
291
292
293
294
295
296
297
298
void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
  if (!init_score_.empty()) { init_score_.clear(); }
  num_init_score_ = len;
Guolin Ke's avatar
Guolin Ke committed
299
300
301
302
  init_score_ = std::vector<double>(len);
#pragma omp parallel for schedule(static)
  for (int64_t i = 0; i < num_init_score_; ++i) {
    init_score_[i] = init_score[i];
303
  }
304
  init_score_load_from_file_ = false;
305
306
}

307
void Metadata::SetLabel(const label_t* label, data_size_t len) {
308
  std::lock_guard<std::mutex> lock(mutex_);
309
310
311
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
312
  if (num_data_ != len) {
313
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
314
  }
Guolin Ke's avatar
Guolin Ke committed
315
  if (!label_.empty()) { label_.clear(); }
316
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
317
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
318
319
320
321
322
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = label[i];
  }
}

323
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
324
  std::lock_guard<std::mutex> lock(mutex_);
325
326
327
328
329
330
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
331
  if (num_data_ != len) {
332
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
333
  }
Guolin Ke's avatar
Guolin Ke committed
334
  if (!weights_.empty()) { weights_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
335
  num_weights_ = num_data_;
336
  weights_ = std::vector<label_t>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
337
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
338
339
340
341
  for (data_size_t i = 0; i < num_weights_; ++i) {
    weights_[i] = weights[i];
  }
  LoadQueryWeights();
342
  weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
343
344
}

Guolin Ke's avatar
Guolin Ke committed
345
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
346
  std::lock_guard<std::mutex> lock(mutex_);
347
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
348
  if (query == nullptr || len == 0) {
349
350
351
352
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
353
  data_size_t sum = 0;
Guolin Ke's avatar
Guolin Ke committed
354
#pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
355
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
356
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
357
358
  }
  if (num_data_ != sum) {
359
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
360
  }
Guolin Ke's avatar
Guolin Ke committed
361
  if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
362
  num_queries_ = len;
Guolin Ke's avatar
Guolin Ke committed
363
364
  query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
365
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
366
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
367
368
  }
  LoadQueryWeights();
369
  query_load_from_file_ = false;
370
}
Guolin Ke's avatar
Guolin Ke committed
371

Guolin Ke's avatar
Guolin Ke committed
372
373
374
375
376
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
377
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
378
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
379
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
380
381
    return;
  }
382
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
383
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
384
  weights_ = std::vector<label_t>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
385
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
386
  for (data_size_t i = 0; i < num_weights_; ++i) {
387
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
388
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
389
    weights_[i] = static_cast<label_t>(tmp_weight);
Guolin Ke's avatar
Guolin Ke committed
390
  }
391
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
392
393
}

394
void Metadata::LoadInitialScore(const char* initscore_file) {
Guolin Ke's avatar
Guolin Ke committed
395
  num_init_score_ = 0;
396
397
398
399
400
401
  std::string init_score_filename(initscore_file);
  if (init_score_filename.size() <= 0) {
    init_score_filename = std::string(data_filename_);
    // default weight file name
    init_score_filename.append(".init");
  }
Guolin Ke's avatar
Guolin Ke committed
402
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
403
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
404
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
405
406
    return;
  }
407
408
  Log::Info("Loading initial scores...");

409
410
411
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
412
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
413

Guolin Ke's avatar
Guolin Ke committed
414
  init_score_ = std::vector<double>(num_init_score_);
415
  if (num_class == 1) {
Guolin Ke's avatar
Guolin Ke committed
416
#pragma omp parallel for schedule(static)
417
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
418
      double tmp = 0.0f;
419
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
Guolin Ke's avatar
Guolin Ke committed
420
      init_score_[i] = static_cast<double>(tmp);
421
    }
422
  } else {
423
    std::vector<std::string> oneline_init_score;
Guolin Ke's avatar
Guolin Ke committed
424
#pragma omp parallel for schedule(static)
425
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
426
      double tmp = 0.0f;
427
428
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
429
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
430
431
432
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
Guolin Ke's avatar
Guolin Ke committed
433
        init_score_[k * num_line + i] = static_cast<double>(tmp);
434
      }
435
    }
Guolin Ke's avatar
Guolin Ke committed
436
  }
437
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
438
439
440
441
442
443
444
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
445
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
446
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
447
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
448
449
    return;
  }
450
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
451
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
452
453
454
455
456
457
458
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
459
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
460
461
462
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
463
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
464
465
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
466
  query_weights_.clear();
467
  Log::Info("Loading query weights...");
468
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_data_);
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_weights_);
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_queries_);

Guolin Ke's avatar
Guolin Ke committed
488
  if (!label_.empty()) { label_.clear(); }
489
490
491
  label_ = std::vector<label_t>(num_data_);
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t)*num_data_);
  mem_ptr += sizeof(label_t)*num_data_;
Guolin Ke's avatar
Guolin Ke committed
492
493

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
494
    if (!weights_.empty()) { weights_.clear(); }
495
496
497
    weights_ = std::vector<label_t>(num_weights_);
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t)*num_weights_);
    mem_ptr += sizeof(label_t)*num_weights_;
498
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
499
500
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
501
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
502
503
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
504
    mem_ptr += sizeof(data_size_t)*(num_queries_ + 1);
505
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
506
  }
Guolin Ke's avatar
Guolin Ke committed
507
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
508
509
}

510
511
512
513
514
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
  writer->Write(&num_data_, sizeof(num_data_));
  writer->Write(&num_weights_, sizeof(num_weights_));
  writer->Write(&num_queries_, sizeof(num_queries_));
  writer->Write(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
515
  if (!weights_.empty()) {
516
    writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
517
  }
Guolin Ke's avatar
Guolin Ke committed
518
  if (!query_boundaries_.empty()) {
519
    writer->Write(query_boundaries_.data(), sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
520
521
522
523
524
525
  }
}

size_t Metadata::SizesInByte() const  {
  size_t size = sizeof(num_data_) + sizeof(num_weights_)
    + sizeof(num_queries_);
526
  size += sizeof(label_t) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
527
  if (!weights_.empty()) {
528
    size += sizeof(label_t) * num_weights_;
Guolin Ke's avatar
Guolin Ke committed
529
  }
Guolin Ke's avatar
Guolin Ke committed
530
  if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
531
532
533
534
535
536
537
    size += sizeof(data_size_t) * (num_queries_ + 1);
  }
  return size;
}


}  // namespace LightGBM