metadata.cpp 18.3 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>

#include <string>
9
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
10
11
12

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
13
/*!
 * \brief Construct an empty metadata object.
 *        Every element counter is zero and no field is marked as having
 *        been loaded from a side file.
 */
Metadata::Metadata() {
  // nothing has been read from a companion file yet
  init_score_load_from_file_ = false;
  query_load_from_file_ = false;
  weight_load_from_file_ = false;
  // all counters start out at zero
  num_queries_ = 0;
  num_data_ = 0;
  num_init_score_ = 0;
  num_weights_ = 0;
}

23
void Metadata::Init(const char* data_filename, const char* initscore_file) {
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
  data_filename_ = data_filename;
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
29
  LoadInitialScore(initscore_file);
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
}

/*! \brief Destructor; has no work of its own beyond destroying the members. */
Metadata::~Metadata() = default;

35
/*!
 * \brief Allocate metadata buffers for num_data rows.
 * \param num_data Number of rows; the label buffer is always sized to it
 * \param weight_idx When >= 0, a zero-filled weight buffer of num_data
 *        entries replaces any weights that are already stored
 * \param query_idx When >= 0, a zero-filled per-row query-id buffer is
 *        allocated and any stored query boundaries / query weights are dropped
 */
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  label_ = std::vector<label_t>(num_data_);

  const bool weights_in_data = (weight_idx >= 0);
  if (weights_in_data) {
    const bool had_weights = !weights_.empty();
    if (had_weights) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
      weights_.clear();
    }
    weight_load_from_file_ = false;
    num_weights_ = num_data_;
    weights_ = std::vector<label_t>(num_data_, 0.0f);
  }

  const bool queries_in_data = (query_idx >= 0);
  if (queries_in_data) {
    const bool had_boundaries = !query_boundaries_.empty();
    if (had_boundaries) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    // clearing an already empty vector is a no-op, so no emptiness test is needed
    query_weights_.clear();
    query_load_from_file_ = false;
    queries_ = std::vector<data_size_t>(num_data_, 0);
  }
}

Guolin Ke's avatar
Guolin Ke committed
58
59
60
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

61
  label_ = std::vector<label_t>(num_used_indices);
62
63
  #pragma omp parallel for schedule(static)
  for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
64
65
66
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
67
  if (!fullset.weights_.empty()) {
68
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
69
    num_weights_ = num_used_indices;
70
71
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
78
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
79
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
80
    init_score_ = std::vector<double>(static_cast<size_t>(num_used_indices) * num_class);
Guolin Ke's avatar
Guolin Ke committed
81
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
82
    #pragma omp parallel for schedule(static)
83
    for (int k = 0; k < num_class; ++k) {
84
85
86
87
      const size_t offset_dest = static_cast<size_t>(k) * num_data_;
      const size_t offset_src = static_cast<size_t>(k) * fullset.num_data_;
      for (data_size_t i = 0; i < num_used_indices; ++i) {
        init_score_[offset_dest + i] = fullset.init_score_[offset_src + used_indices[i]];
88
      }
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
93
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
94
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

Guolin Ke's avatar
Guolin Ke committed
127
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
128
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
129
130
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
131
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
132
  num_data_ = static_cast<data_size_t>(used_indices.size());
133
  label_ = std::vector<label_t>(num_data_);
134
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
135
136
137
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
138
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
139
140
141
}

/*!
 * \brief Validate the stored metadata against the data size, or cut it down
 *        to the locally used rows.
 *
 *        With an empty used_data_indices, per-row query ids (if any) are
 *        converted to query boundaries and the sizes of weights, query
 *        boundaries and initial scores are checked against num_data_.
 *        Otherwise, every field that was loaded from a file is checked
 *        against num_all_data and then reduced to the rows listed in
 *        used_data_indices.
 * \param num_all_data Number of rows of the complete data set
 * \param used_data_indices Rows kept by this object; empty means "all rows"
 */
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // need convert query_id to boundaries: count the length of each run of equal ids
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      tmp_buffer.push_back(cur_cnt);
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
      queries_.clear();
    }
    // check weights
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // check query boundaries
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // contain initial score file
    // (num_data_ is tested first so the modulo never divides by zero)
    if (!init_score_.empty() && (num_data_ <= 0 || (num_init_score_ % num_data_) != 0)) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for parallel training");
    }
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // check weights
    if (weight_load_from_file_) {
      if (weights_.size() > 0 && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // get local weights
      if (!weights_.empty()) {
        // swap the full buffer out rather than copying it
        std::vector<label_t> old_weights;
        old_weights.swap(weights_);
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
        #pragma omp parallel for schedule(static)
        for (data_size_t i = 0; i < num_used_data; ++i) {
          weights_[i] = old_weights[used_data_indices[i]];
        }
      }
    }
    if (query_load_from_file_) {
      // check query boundaries
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // get local query boundaries
      if (!query_boundaries_.empty()) {
        std::vector<data_size_t> used_query;
        data_size_t data_idx = 0;
        for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
          data_size_t start = query_boundaries_[qid];
          data_size_t end = query_boundaries_[qid + 1];
          data_size_t len = end - start;
          if (used_data_indices[data_idx] > start) {
            // the next used row lies beyond the start of this query: query not used locally
            continue;
          } else if (used_data_indices[data_idx] == start) {
            // a used query must be present with all of its rows
            if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
              used_query.push_back(qid);
              data_idx += len;
            } else {
              Log::Fatal("Data partition error, data didn't match queries");
            }
          } else {
            Log::Fatal("Data partition error, data didn't match queries");
          }
        }
        // swap the full boundaries out rather than copying them
        std::vector<data_size_t> old_query_boundaries;
        old_query_boundaries.swap(query_boundaries_);
        query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
        num_queries_ = static_cast<data_size_t>(used_query.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          data_size_t qid = used_query[i];
          data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
          query_boundaries_[i + 1] = query_boundaries_[i] + len;
        }
      }
    }
    if (init_score_load_from_file_) {
      // contain initial score file
      // (num_all_data is tested first so the modulo never divides by zero)
      if (!init_score_.empty() && (num_all_data <= 0 || (num_init_score_ % num_all_data) != 0)) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }

      // get local initial scores
      if (!init_score_.empty()) {
        // swap the full buffer out rather than copying it
        std::vector<double> old_scores;
        old_scores.swap(init_score_);
        int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
        #pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          // scores are stored class by class, one value per row within a class
          const size_t offset_dest = static_cast<size_t>(k) * num_data_;
          const size_t offset_src = static_cast<size_t>(k) * num_all_data;
          for (size_t i = 0; i < used_data_indices.size(); ++i) {
            init_score_[offset_dest + i] = old_scores[offset_src + used_data_indices[i]];
          }
        }
      }
    }
    // re-load query weight
    LoadQueryWeights();
  }
}

282
283
284
285
286
287
288
289
290
291
292
void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
293
  if (init_score_.empty()) { init_score_.resize(len); }
294
  num_init_score_ = len;
295

296
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
297
  for (int64_t i = 0; i < num_init_score_; ++i) {
298
    init_score_[i] = Common::AvoidInf(init_score[i]);
299
  }
300
  init_score_load_from_file_ = false;
301
302
}

303
void Metadata::SetLabel(const label_t* label, data_size_t len) {
304
  std::lock_guard<std::mutex> lock(mutex_);
305
306
307
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
308
  if (num_data_ != len) {
309
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
310
  }
311
  if (label_.empty()) { label_.resize(num_data_); }
312

313
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
314
  for (data_size_t i = 0; i < num_data_; ++i) {
315
    label_[i] = Common::AvoidInf(label[i]);
Guolin Ke's avatar
Guolin Ke committed
316
317
318
  }
}

319
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
320
  std::lock_guard<std::mutex> lock(mutex_);
321
322
323
324
325
326
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
327
  if (num_data_ != len) {
328
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
329
  }
330
  if (weights_.empty()) { weights_.resize(num_data_); }
Guolin Ke's avatar
Guolin Ke committed
331
  num_weights_ = num_data_;
332

333
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
334
  for (data_size_t i = 0; i < num_weights_; ++i) {
335
    weights_[i] = Common::AvoidInf(weights[i]);
Guolin Ke's avatar
Guolin Ke committed
336
337
  }
  LoadQueryWeights();
338
  weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
339
340
}

Guolin Ke's avatar
Guolin Ke committed
341
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
342
  std::lock_guard<std::mutex> lock(mutex_);
343
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
344
  if (query == nullptr || len == 0) {
345
346
347
348
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
349
  data_size_t sum = 0;
350
  #pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
351
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
352
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
353
354
  }
  if (num_data_ != sum) {
355
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
356
357
  }
  num_queries_ = len;
358
  query_boundaries_.resize(num_queries_ + 1);
Guolin Ke's avatar
Guolin Ke committed
359
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
360
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
361
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
362
363
  }
  LoadQueryWeights();
364
  query_load_from_file_ = false;
365
}
Guolin Ke's avatar
Guolin Ke committed
366

Guolin Ke's avatar
Guolin Ke committed
367
368
369
370
371
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
372
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
373
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
374
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
375
376
    return;
  }
377
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
378
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
379
  weights_ = std::vector<label_t>(num_weights_);
380
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
381
  for (data_size_t i = 0; i < num_weights_; ++i) {
382
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
383
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
384
    weights_[i] = Common::AvoidInf(static_cast<label_t>(tmp_weight));
Guolin Ke's avatar
Guolin Ke committed
385
  }
386
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
387
388
}

389
void Metadata::LoadInitialScore(const char* initscore_file) {
Guolin Ke's avatar
Guolin Ke committed
390
  num_init_score_ = 0;
391
392
393
394
395
396
  std::string init_score_filename(initscore_file);
  if (init_score_filename.size() <= 0) {
    init_score_filename = std::string(data_filename_);
    // default weight file name
    init_score_filename.append(".init");
  }
Guolin Ke's avatar
Guolin Ke committed
397
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
398
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
399
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
400
401
    return;
  }
402
403
  Log::Info("Loading initial scores...");

404
405
406
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
407
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
408

Guolin Ke's avatar
Guolin Ke committed
409
  init_score_ = std::vector<double>(num_init_score_);
410
  if (num_class == 1) {
411
    #pragma omp parallel for schedule(static)
412
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
413
      double tmp = 0.0f;
414
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
415
      init_score_[i] = Common::AvoidInf(static_cast<double>(tmp));
416
    }
417
  } else {
418
    std::vector<std::string> oneline_init_score;
419
    #pragma omp parallel for schedule(static)
420
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
421
      double tmp = 0.0f;
422
423
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
424
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
425
426
427
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
428
        init_score_[static_cast<size_t>(k) * num_line + i] = Common::AvoidInf(static_cast<double>(tmp));
429
      }
430
    }
Guolin Ke's avatar
Guolin Ke committed
431
  }
432
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
437
438
439
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
440
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
441
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
442
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
443
444
    return;
  }
445
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
446
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
447
448
449
450
451
452
453
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
454
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
455
456
457
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
458
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
459
460
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
461
  query_weights_.clear();
462
  Log::Info("Loading query weights...");
463
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_data_);
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_weights_);
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_queries_);

Guolin Ke's avatar
Guolin Ke committed
483
  if (!label_.empty()) { label_.clear(); }
484
  label_ = std::vector<label_t>(num_data_);
485
486
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t) * num_data_);
  mem_ptr += sizeof(label_t) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
487
488

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
489
    if (!weights_.empty()) { weights_.clear(); }
490
    weights_ = std::vector<label_t>(num_weights_);
491
492
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t) * num_weights_);
    mem_ptr += sizeof(label_t) * num_weights_;
493
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
494
495
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
496
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
497
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
498
499
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t) * (num_queries_ + 1));
    mem_ptr += sizeof(data_size_t) * (num_queries_ + 1);
500
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
501
  }
Guolin Ke's avatar
Guolin Ke committed
502
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
503
504
}

505
506
507
508
509
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
  writer->Write(&num_data_, sizeof(num_data_));
  writer->Write(&num_weights_, sizeof(num_weights_));
  writer->Write(&num_queries_, sizeof(num_queries_));
  writer->Write(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
510
  if (!weights_.empty()) {
511
    writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
512
  }
Guolin Ke's avatar
Guolin Ke committed
513
  if (!query_boundaries_.empty()) {
514
    writer->Write(query_boundaries_.data(), sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
515
  }
516
517
518
519
  if (num_init_score_ > 0) {
    Log::Warning("Please note that `init_score` is not saved in binary file.\n"
      "If you need it, please set it again after loading Dataset.");
  }
Guolin Ke's avatar
Guolin Ke committed
520
521
}

522
size_t Metadata::SizesInByte() const {
Guolin Ke's avatar
Guolin Ke committed
523
524
  size_t size = sizeof(num_data_) + sizeof(num_weights_)
    + sizeof(num_queries_);
525
  size += sizeof(label_t) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
526
  if (!weights_.empty()) {
527
    size += sizeof(label_t) * num_weights_;
Guolin Ke's avatar
Guolin Ke committed
528
  }
Guolin Ke's avatar
Guolin Ke committed
529
  if (!query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
530
531
532
533
534
535
536
    size += sizeof(data_size_t) * (num_queries_ + 1);
  }
  return size;
}


}  // namespace LightGBM