metadata.cpp 17.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
#include <LightGBM/dataset.h>

#include <LightGBM/utils/common.h>

#include <vector>
#include <string>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
// Default constructor. Performs no work of its own; the object is filled in
// later through one of the Init() overloads, LoadFromMemory() or the Set*()
// methods.
Metadata::Metadata() {}

13
void Metadata::Init(const char * data_filename) {
Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
18
19
20
21
22
23
24
  data_filename_ = data_filename;
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
  LoadInitialScore();
}

// Destructor. Nothing to release by hand here.
Metadata::~Metadata() {}

25
// Prepares zero-filled storage for metadata that will be read from the data
// file itself.
//
// num_data:   number of rows to allocate for.
// weight_idx: when >= 0, per-row weight storage is allocated and any weights
//             already held by this object are discarded.
// query_idx:  when >= 0, per-row query id storage is allocated and any query
//             boundaries / query weights already held are discarded.
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  label_.assign(num_data_, 0.0f);

  const bool weights_in_data_file = (weight_idx >= 0);
  if (weights_in_data_file) {
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
    }
    // assign() replaces whatever was stored before with num_data_ zeros.
    weights_.assign(num_data_, 0.0f);
    num_weights_ = num_data_;
  }

  const bool queries_in_data_file = (query_idx >= 0);
  if (queries_in_data_file) {
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    query_weights_.clear();
    // Per-row query ids, all zero at this point.
    queries_.assign(num_data_, 0);
  }
}

Guolin Ke's avatar
Guolin Ke committed
48
49
50
51
52
53
54
55
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

  label_ = std::vector<float>(num_used_indices);
  for (data_size_t i = 0; i < num_used_indices; i++) {
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
56
  if (!fullset.weights_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
61
62
63
64
65
    weights_ = std::vector<float>(num_used_indices);
    num_weights_ = num_used_indices;
    for (data_size_t i = 0; i < num_used_indices; i++) {
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
66
  if (!fullset.init_score_.empty()) {
67
68
69
70
71
72
73
    int num_class = static_cast<int>(fullset.num_init_score_) / fullset.num_data_;
    init_score_ = std::vector<float>(num_used_indices*num_class);
    num_init_score_ = num_used_indices*num_class;
    for (int k = 0; k < num_class; ++k) {
      for (data_size_t i = 0; i < num_used_indices; i++) {
        init_score_[k*num_data_ + i] = fullset.init_score_[k* fullset.num_data_ + used_indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
78
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
79
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }

}

Guolin Ke's avatar
Guolin Ke committed
113
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
114
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
115
116
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
117
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
118
  num_data_ = static_cast<data_size_t>(used_indices.size());
Guolin Ke's avatar
Guolin Ke committed
119
  label_ = std::vector<float>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
120
121
122
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
123
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
124
125
126
}

void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
127
128
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
      // need convert query_id to boundaries
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      tmp_buffer.push_back(cur_cnt);
Guolin Ke's avatar
Guolin Ke committed
144
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
151
      queries_.clear();
Guolin Ke's avatar
Guolin Ke committed
152
    }
Guolin Ke's avatar
Guolin Ke committed
153
    // check weights
Guolin Ke's avatar
Guolin Ke committed
154
    if (!weights_.empty() && num_weights_ != num_data_) {
Guolin Ke's avatar
Guolin Ke committed
155
      weights_.clear();
Guolin Ke's avatar
Guolin Ke committed
156
      num_weights_ = 0;
Guolin Ke's avatar
Guolin Ke committed
157
      Log::Fatal("Weights size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
158
159
160
    }

    // check query boundries
Guolin Ke's avatar
Guolin Ke committed
161
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
Guolin Ke's avatar
Guolin Ke committed
162
      query_boundaries_.clear();
Guolin Ke's avatar
Guolin Ke committed
163
      num_queries_ = 0;
Guolin Ke's avatar
Guolin Ke committed
164
      Log::Fatal("Query size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
165
166
167
    }

    // contain initial score file
168
    if (!init_score_.empty() && (num_init_score_ % num_data_) != 0) {
Guolin Ke's avatar
Guolin Ke committed
169
      init_score_.clear();
Guolin Ke's avatar
Guolin Ke committed
170
      num_init_score_ = 0;
Guolin Ke's avatar
Guolin Ke committed
171
      Log::Fatal("Initial score size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
    }
  } else {
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // check weights
Guolin Ke's avatar
Guolin Ke committed
176
177
    if (weights_.size() > 0 && num_weights_ != num_all_data) {
      weights_.clear();
Guolin Ke's avatar
Guolin Ke committed
178
      num_weights_ = 0;
Guolin Ke's avatar
Guolin Ke committed
179
      Log::Fatal("Weights size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
180
181
    }
    // check query boundries
Guolin Ke's avatar
Guolin Ke committed
182
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
Guolin Ke's avatar
Guolin Ke committed
183
      query_boundaries_.clear();
Guolin Ke's avatar
Guolin Ke committed
184
      num_queries_ = 0;
Guolin Ke's avatar
Guolin Ke committed
185
      Log::Fatal("Query size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
186
187
188
    }

    // contain initial score file
189
    if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
Guolin Ke's avatar
Guolin Ke committed
190
      init_score_.clear();
Guolin Ke's avatar
Guolin Ke committed
191
      num_init_score_ = 0;
Guolin Ke's avatar
Guolin Ke committed
192
      Log::Fatal("Initial score size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
193
194
195
    }

    // get local weights
Guolin Ke's avatar
Guolin Ke committed
196
    if (!weights_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
197
      auto old_weights = weights_;
Guolin Ke's avatar
Guolin Ke committed
198
      num_weights_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
199
      weights_ = std::vector<float>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
200
201
202
      for (size_t i = 0; i < used_data_indices.size(); ++i) {
        weights_[i] = old_weights[used_data_indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
203
      old_weights.clear();
Guolin Ke's avatar
Guolin Ke committed
204
205
206
    }

    // get local query boundaries
Guolin Ke's avatar
Guolin Ke committed
207
    if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
208
209
210
211
212
213
214
215
216
217
218
219
220
      std::vector<data_size_t> used_query;
      data_size_t data_idx = 0;
      for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
        data_size_t start = query_boundaries_[qid];
        data_size_t end = query_boundaries_[qid + 1];
        data_size_t len = end - start;
        if (used_data_indices[data_idx] > start) {
          continue;
        } else if (used_data_indices[data_idx] == start) {
          if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
            used_query.push_back(qid);
            data_idx += len;
          } else {
Guolin Ke's avatar
Guolin Ke committed
221
            Log::Fatal("Data partition error, data didn't match queries");
Guolin Ke's avatar
Guolin Ke committed
222
223
          }
        } else {
Guolin Ke's avatar
Guolin Ke committed
224
          Log::Fatal("Data partition error, data didn't match queries");
Guolin Ke's avatar
Guolin Ke committed
225
226
        }
      }
Guolin Ke's avatar
Guolin Ke committed
227
228
      auto old_query_boundaries = query_boundaries_;
      query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
Guolin Ke's avatar
Guolin Ke committed
229
230
231
232
233
234
235
      num_queries_ = static_cast<data_size_t>(used_query.size());
      query_boundaries_[0] = 0;
      for (data_size_t i = 0; i < num_queries_; ++i) {
        data_size_t qid = used_query[i];
        data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
        query_boundaries_[i + 1] = query_boundaries_[i] + len;
      }
Guolin Ke's avatar
Guolin Ke committed
236
      old_query_boundaries.clear();
Guolin Ke's avatar
Guolin Ke committed
237
238
239
    }

    // get local initial scores
Guolin Ke's avatar
Guolin Ke committed
240
    if (!init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
241
      auto old_scores = init_score_;
242
243
244
245
      int num_class = num_init_score_ / num_all_data;
      num_init_score_ = num_data_ * num_class;
      init_score_ = std::vector<float>(num_init_score_);
      for (int k = 0; k < num_class; ++k){
246
247
        for (size_t i = 0; i < used_data_indices.size(); ++i) {
          init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
248
        }
Guolin Ke's avatar
Guolin Ke committed
249
      }
Guolin Ke's avatar
Guolin Ke committed
250
      old_scores.clear();
Guolin Ke's avatar
Guolin Ke committed
251
252
253
254
255
256
257
258
    }

    // re-load query weight
    LoadQueryWeights();
  }
}


Guolin Ke's avatar
Guolin Ke committed
259
void Metadata::SetInitScore(const float* init_score, data_size_t len) {
260
  std::lock_guard<std::mutex> lock(mutex_);
261
262
263
264
265
266
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
267
  if ((len % num_data_) != 0) {
268
    Log::Fatal("Initial score size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
269
  }
Guolin Ke's avatar
Guolin Ke committed
270
  if (!init_score_.empty()) { init_score_.clear(); }
271
  num_init_score_ = len;
Guolin Ke's avatar
Guolin Ke committed
272
  init_score_ = std::vector<float>(len);
273
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
274
275
    init_score_[i] = init_score[i];
  }
Guolin Ke's avatar
Guolin Ke committed
276
277
}

Guolin Ke's avatar
Guolin Ke committed
278
void Metadata::SetLabel(const float* label, data_size_t len) {
279
  std::lock_guard<std::mutex> lock(mutex_);
280
281
282
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
283
284
285
  if (num_data_ != len) {
    Log::Fatal("len of label is not same with #data");
  }
Guolin Ke's avatar
Guolin Ke committed
286
  if (!label_.empty()) { label_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
287
  label_ = std::vector<float>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
288
289
290
291
292
293
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = label[i];
  }
}

void Metadata::SetWeights(const float* weights, data_size_t len) {
294
  std::lock_guard<std::mutex> lock(mutex_);
295
296
297
298
299
300
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
301
302
303
  if (num_data_ != len) {
    Log::Fatal("len of weights is not same with #data");
  }
Guolin Ke's avatar
Guolin Ke committed
304
  if (!weights_.empty()) { weights_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
305
  num_weights_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
306
  weights_ = std::vector<float>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
307
308
309
310
311
312
  for (data_size_t i = 0; i < num_weights_; ++i) {
    weights_[i] = weights[i];
  }
  LoadQueryWeights();
}

Guolin Ke's avatar
Guolin Ke committed
313
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
314
  std::lock_guard<std::mutex> lock(mutex_);
315
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
316
  if (query == nullptr || len == 0) {
317
318
319
320
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
321
322
  data_size_t sum = 0;
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
323
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
324
325
326
327
  }
  if (num_data_ != sum) {
    Log::Fatal("sum of query counts is not same with #data");
  }
Guolin Ke's avatar
Guolin Ke committed
328
  if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
329
  num_queries_ = len;
Guolin Ke's avatar
Guolin Ke committed
330
331
  query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
332
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
333
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
  }
  LoadQueryWeights();
}

338
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
339
  std::lock_guard<std::mutex> lock(mutex_);
340
341
342
343
344
345
346
  // save to nullptr
  if (query_id == nullptr || len == 0) {
    query_boundaries_.clear();
    queries_.clear();
    num_queries_ = 0;
    return;
  }
347
348
349
  if (num_data_ != len) {
    Log::Fatal("len of query id is not same with #data");
  }
Guolin Ke's avatar
Guolin Ke committed
350
  if (!queries_.empty()) { queries_.clear(); }
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
  queries_ = std::vector<data_size_t>(num_data_);
  for (data_size_t i = 0; i < num_weights_; ++i) {
    queries_[i] = query_id[i];
  }
  // need convert query_id to boundaries
  std::vector<data_size_t> tmp_buffer;
  data_size_t last_qid = -1;
  data_size_t cur_cnt = 0;
  for (data_size_t i = 0; i < num_data_; ++i) {
    if (last_qid != queries_[i]) {
      if (cur_cnt > 0) {
        tmp_buffer.push_back(cur_cnt);
      }
      cur_cnt = 0;
      last_qid = queries_[i];
    }
    ++cur_cnt;
  }
  tmp_buffer.push_back(cur_cnt);
  query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
  num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < tmp_buffer.size(); ++i) {
    query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
  }
  queries_.clear();
  LoadQueryWeights();
}
Guolin Ke's avatar
Guolin Ke committed
379

Guolin Ke's avatar
Guolin Ke committed
380
381
382
383
384
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
385
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
386
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
387
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
388
389
    return;
  }
390
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
391
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
392
  weights_ = std::vector<float>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
393
  for (data_size_t i = 0; i < num_weights_; ++i) {
394
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
395
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
396
    weights_[i] = static_cast<float>(tmp_weight);
Guolin Ke's avatar
Guolin Ke committed
397
398
399
400
401
  }
}

void Metadata::LoadInitialScore() {
  num_init_score_ = 0;
Guolin Ke's avatar
Guolin Ke committed
402
403
404
405
  std::string init_score_filename(data_filename_);
  // default weight file name
  init_score_filename.append(".init");
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
406
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
407
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
408
409
    return;
  }
410
411
  Log::Info("Loading initial scores...");

412
413
414
415
416
417
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
  num_init_score_ = static_cast<data_size_t>(num_line * num_class);

  init_score_ = std::vector<float>(num_init_score_);
418
  double tmp = 0.0f;
419

420
421
422
423
424
  if (num_class == 1) {
    for (data_size_t i = 0; i < num_line; ++i) {
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
      init_score_[i] = static_cast<float>(tmp);
    }
425
  } else {
426
427
428
429
430
431
432
433
434
    std::vector<std::string> oneline_init_score;
    for (data_size_t i = 0; i < num_line; ++i) {
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns.");
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
        init_score_[k * num_line + i] = static_cast<float>(tmp);
435
      }
436
    }
Guolin Ke's avatar
Guolin Ke committed
437
438
439
440
441
442
443
444
  }
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
445
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
446
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
447
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
448
449
    return;
  }
450
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
451
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
452
453
454
455
456
457
458
459
460
461
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
462
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
463
464
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
465
  query_weights_.clear();
466
  Log::Info("Loading query weights...");
Guolin Ke's avatar
Guolin Ke committed
467
  query_weights_ = std::vector<float>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_data_);
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_weights_);
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_queries_);

Guolin Ke's avatar
Guolin Ke committed
487
  if (!label_.empty()) { label_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
488
489
  label_ = std::vector<float>(num_data_);
  std::memcpy(label_.data(), mem_ptr, sizeof(float)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
490
  mem_ptr += sizeof(float)*num_data_;
Guolin Ke's avatar
Guolin Ke committed
491
492

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
493
    if (!weights_.empty()) { weights_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
494
495
    weights_ = std::vector<float>(num_weights_);
    std::memcpy(weights_.data(), mem_ptr, sizeof(float)*num_weights_);
Guolin Ke's avatar
Guolin Ke committed
496
497
498
    mem_ptr += sizeof(float)*num_weights_;
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
499
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
500
501
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
502
503
    mem_ptr += sizeof(data_size_t)*(num_queries_ + 1);
  }
Guolin Ke's avatar
Guolin Ke committed
504
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
505
506
507
508
509
510
}

void Metadata::SaveBinaryToFile(FILE* file) const {
  fwrite(&num_data_, sizeof(num_data_), 1, file);
  fwrite(&num_weights_, sizeof(num_weights_), 1, file);
  fwrite(&num_queries_, sizeof(num_queries_), 1, file);
Guolin Ke's avatar
Guolin Ke committed
511
  fwrite(label_.data(), sizeof(float), num_data_, file);
Guolin Ke's avatar
Guolin Ke committed
512
  if (!weights_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
513
    fwrite(weights_.data(), sizeof(float), num_weights_, file);
Guolin Ke's avatar
Guolin Ke committed
514
  }
Guolin Ke's avatar
Guolin Ke committed
515
  if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
516
    fwrite(query_boundaries_.data(), sizeof(data_size_t), num_queries_ + 1, file);
Guolin Ke's avatar
Guolin Ke committed
517
518
519
520
521
522
523
524
  }

}

size_t Metadata::SizesInByte() const  {
  size_t size = sizeof(num_data_) + sizeof(num_weights_)
    + sizeof(num_queries_);
  size += sizeof(float) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
525
  if (!weights_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
526
527
    size += sizeof(float) * num_weights_;
  }
Guolin Ke's avatar
Guolin Ke committed
528
  if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
529
530
531
532
533
534
535
    size += sizeof(data_size_t) * (num_queries_ + 1);
  }
  return size;
}


}  // namespace LightGBM