metadata.cpp 19.1 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
7
8

#include <string>
9
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
10
11
12

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
13
/*!
 * \brief Construct an empty Metadata: every counter is zero and nothing is
 *        flagged as having been loaded from a side file.
 */
Metadata::Metadata() {
  // Row / weight / query / init-score counters start out empty.
  num_data_ = 0;
  num_weights_ = 0;
  num_queries_ = 0;
  num_init_score_ = 0;
  // No ".weight", ".query" or ".init" side file has been read yet.
  init_score_load_from_file_ = false;
  query_load_from_file_ = false;
  weight_load_from_file_ = false;
}

23
void Metadata::Init(const char* data_filename) {
Guolin Ke's avatar
Guolin Ke committed
24
  data_filename_ = data_filename;
25
  // for lambdarank, it needs query data for partition data in distributed learning
Guolin Ke's avatar
Guolin Ke committed
26
27
28
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
29
  LoadInitialScore();
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
}

// Nothing to release by hand: the buffers used in this file are standard
// containers that free their own storage.
Metadata::~Metadata() {}

35
/*!
 * \brief Allocate per-row storage for a dataset of \p num_data rows.
 *
 * Labels are always (re)allocated. If the data file carries a weight column
 * (weight_idx >= 0) or a query-id column (query_idx >= 0), the matching buffer
 * is allocated zero-filled here. Anything previously loaded from the separate
 * side file is discarded in favour of the in-file column.
 *
 * \param num_data Number of rows.
 * \param weight_idx >= 0 when the data file has a weight column (presumably its
 *        column index; only the sign is used here).
 * \param query_idx >= 0 when the data file has a query-id column (only the sign
 *        is used here).
 */
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  label_ = std::vector<label_t>(num_data_);

  const bool has_weight_column = weight_idx >= 0;
  if (has_weight_column) {
    // In-file weights win over a separate ".weight" file.
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
      weights_.clear();
    }
    weights_ = std::vector<label_t>(num_data_, 0.0f);
    num_weights_ = num_data_;
    weight_load_from_file_ = false;
  }

  const bool has_query_column = query_idx >= 0;
  if (has_query_column) {
    // In-file query ids win over a separate ".query" file.
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    // Any query weights belonged to the discarded query boundaries.
    query_weights_.clear();
    queries_ = std::vector<data_size_t>(num_data_, 0);
    query_load_from_file_ = false;
  }
}

Guolin Ke's avatar
Guolin Ke committed
58
59
60
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

61
  label_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
62
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
63
  for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
64
65
66
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
67
  if (!fullset.weights_.empty()) {
68
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
69
    num_weights_ = num_used_indices;
Guolin Ke's avatar
Guolin Ke committed
70
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
71
    for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
78
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
79
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
80
    init_score_ = std::vector<double>(static_cast<size_t>(num_used_indices) * num_class);
Guolin Ke's avatar
Guolin Ke committed
81
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
82
    #pragma omp parallel for schedule(static)
83
    for (int k = 0; k < num_class; ++k) {
84
85
86
87
      const size_t offset_dest = static_cast<size_t>(k) * num_data_;
      const size_t offset_src = static_cast<size_t>(k) * fullset.num_data_;
      for (data_size_t i = 0; i < num_used_indices; ++i) {
        init_score_[offset_dest + i] = fullset.init_score_[offset_src + used_indices[i]];
88
      }
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
93
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
94
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

Guolin Ke's avatar
Guolin Ke committed
127
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
128
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
129
130
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
131
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
132
  num_data_ = static_cast<data_size_t>(used_indices.size());
133
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
134
#pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
135
136
137
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
138
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
139
140
141
}

/*!
 * \brief Validate the loaded metadata against the row count, or cut it down
 *        to the local rows when the data is partitioned.
 *
 * Empty \p used_data_indices: all rows are local. Per-row query ids (if any)
 * are converted into query boundaries. Weights, query boundaries and initial
 * scores are then checked against num_data_; any mismatch is fatal.
 *
 * Non-empty \p used_data_indices (distributed training): metadata loaded from
 * side files spans \p num_all_data rows. It is checked against that count,
 * then gathered down to the listed rows, and query weights are recomputed.
 * Query ids taken from the data file are rejected in this mode.
 *
 * \param num_all_data Row count the side files are checked against
 *        (presumably the total over all machines).
 * \param used_data_indices Row indices kept locally; empty means all rows.
 */
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // Per-row query ids -> lengths of consecutive runs of the same id.
      std::vector<data_size_t> run_lengths;
      data_size_t prev_id = -1;
      data_size_t run = 0;
      for (data_size_t row = 0; row < num_data_; ++row) {
        if (queries_[row] != prev_id) {
          if (run > 0) {
            run_lengths.push_back(run);
          }
          run = 0;
          prev_id = queries_[row];
        }
        ++run;
      }
      run_lengths.push_back(run);
      // Prefix-sum the run lengths into query boundaries.
      query_boundaries_ = std::vector<data_size_t>(run_lengths.size() + 1);
      num_queries_ = static_cast<data_size_t>(run_lengths.size());
      query_boundaries_[0] = 0;
      for (size_t q = 0; q < run_lengths.size(); ++q) {
        query_boundaries_[q + 1] = query_boundaries_[q] + run_lengths[q];
      }
      LoadQueryWeights();
      queries_.clear();
    }

    // Weights must cover every row.
    const bool weights_mismatch = !weights_.empty() && num_weights_ != num_data_;
    if (weights_mismatch) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // The last query boundary must equal the row count.
    const bool queries_mismatch = !query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_;
    if (queries_mismatch) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // Initial scores must be a whole number of per-row blocks.
    const bool init_score_mismatch = !init_score_.empty() && (num_init_score_ % num_data_) != 0;
    if (init_score_mismatch) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for distributed training");
    }
    const data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());

    if (weight_load_from_file_) {
      if (!weights_.empty() && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // Gather the local rows' weights.
      if (!weights_.empty()) {
        std::vector<label_t> global_weights;
        global_weights.swap(weights_);
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
        const int num_local = static_cast<int>(used_data_indices.size());
#pragma omp parallel for schedule(static, 512)
        for (int i = 0; i < num_local; ++i) {
          weights_[i] = global_weights[used_data_indices[i]];
        }
      }
    }

    if (query_load_from_file_) {
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // Keep the queries whose rows are all local; a query that is only
      // partly local is a fatal partitioning error.
      if (!query_boundaries_.empty()) {
        std::vector<data_size_t> kept_queries;
        data_size_t pos = 0;
        for (data_size_t qid = 0; qid < num_queries_ && pos < num_used_data; ++qid) {
          const data_size_t first_row = query_boundaries_[qid];
          const data_size_t end_row = query_boundaries_[qid + 1];
          const data_size_t query_len = end_row - first_row;
          const data_size_t next_row = used_data_indices[pos];
          if (next_row > first_row) {
            continue;  // query is not local
          } else if (next_row < first_row) {
            Log::Fatal("Data partition error, data didn't match queries");
          } else if (num_used_data < pos + query_len || used_data_indices[pos + query_len - 1] != end_row - 1) {
            Log::Fatal("Data partition error, data didn't match queries");
          } else {
            kept_queries.push_back(qid);
            pos += query_len;
          }
        }
        std::vector<data_size_t> global_bounds;
        global_bounds.swap(query_boundaries_);
        query_boundaries_ = std::vector<data_size_t>(kept_queries.size() + 1);
        num_queries_ = static_cast<data_size_t>(kept_queries.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          const data_size_t qid = kept_queries[i];
          query_boundaries_[i + 1] = query_boundaries_[i] + (global_bounds[qid + 1] - global_bounds[qid]);
        }
      }
    }

    if (init_score_load_from_file_) {
      if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }
      // Gather the local rows' scores, class block by class block.
      if (!init_score_.empty()) {
        std::vector<double> global_scores;
        global_scores.swap(init_score_);
        const int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          const size_t dst_base = static_cast<size_t>(k) * num_data_;
          const size_t src_base = static_cast<size_t>(k) * num_all_data;
          for (size_t i = 0; i < used_data_indices.size(); ++i) {
            init_score_[dst_base + i] = global_scores[src_base + used_data_indices[i]];
          }
        }
      }
    }
    // Weights and/or queries may have changed: rebuild query weights.
    LoadQueryWeights();
  }
  if (num_queries_ > 0) {
    Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.",
      data_filename_.c_str(), static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
  }
}

286
287
288
289
290
291
292
293
294
295
296
void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
297
  if (init_score_.empty()) { init_score_.resize(len); }
298
  num_init_score_ = len;
299

Guolin Ke's avatar
Guolin Ke committed
300
  #pragma omp parallel for schedule(static, 512) if (num_init_score_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
301
  for (int64_t i = 0; i < num_init_score_; ++i) {
302
    init_score_[i] = Common::AvoidInf(init_score[i]);
303
  }
304
  init_score_load_from_file_ = false;
305
306
}

307
void Metadata::SetLabel(const label_t* label, data_size_t len) {
308
  std::lock_guard<std::mutex> lock(mutex_);
309
310
311
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
312
  if (num_data_ != len) {
313
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
314
  }
315
  if (label_.empty()) { label_.resize(num_data_); }
316

Guolin Ke's avatar
Guolin Ke committed
317
  #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
318
  for (data_size_t i = 0; i < num_data_; ++i) {
319
    label_[i] = Common::AvoidInf(label[i]);
Guolin Ke's avatar
Guolin Ke committed
320
321
322
  }
}

323
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
324
  std::lock_guard<std::mutex> lock(mutex_);
325
326
327
328
329
330
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
331
  if (num_data_ != len) {
332
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
333
  }
334
  if (weights_.empty()) { weights_.resize(num_data_); }
Guolin Ke's avatar
Guolin Ke committed
335
  num_weights_ = num_data_;
336

Guolin Ke's avatar
Guolin Ke committed
337
  #pragma omp parallel for schedule(static, 512) if (num_weights_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
338
  for (data_size_t i = 0; i < num_weights_; ++i) {
339
    weights_[i] = Common::AvoidInf(weights[i]);
Guolin Ke's avatar
Guolin Ke committed
340
341
  }
  LoadQueryWeights();
342
  weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
343
344
}

Guolin Ke's avatar
Guolin Ke committed
345
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
346
  std::lock_guard<std::mutex> lock(mutex_);
347
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
348
  if (query == nullptr || len == 0) {
349
350
351
352
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
353
  data_size_t sum = 0;
354
  #pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
355
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
356
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
357
358
  }
  if (num_data_ != sum) {
359
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
360
361
  }
  num_queries_ = len;
362
  query_boundaries_.resize(num_queries_ + 1);
Guolin Ke's avatar
Guolin Ke committed
363
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
364
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
365
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
366
367
  }
  LoadQueryWeights();
368
  query_load_from_file_ = false;
369
}
Guolin Ke's avatar
Guolin Ke committed
370

Guolin Ke's avatar
Guolin Ke committed
371
372
373
374
375
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
376
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
377
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
378
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
379
380
    return;
  }
381
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
382
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
383
  weights_ = std::vector<label_t>(num_weights_);
384
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
385
  for (data_size_t i = 0; i < num_weights_; ++i) {
386
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
387
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
388
    weights_[i] = Common::AvoidInf(static_cast<label_t>(tmp_weight));
Guolin Ke's avatar
Guolin Ke committed
389
  }
390
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
391
392
}

393
void Metadata::LoadInitialScore() {
Guolin Ke's avatar
Guolin Ke committed
394
  num_init_score_ = 0;
395
396
397
398
  std::string init_score_filename(data_filename_);
  init_score_filename = std::string(data_filename_);
  // default init_score file name
  init_score_filename.append(".init");
Guolin Ke's avatar
Guolin Ke committed
399
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
400
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
401
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
402
403
    return;
  }
404
405
  Log::Info("Loading initial scores...");

406
407
408
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
409
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
410

Guolin Ke's avatar
Guolin Ke committed
411
  init_score_ = std::vector<double>(num_init_score_);
412
  if (num_class == 1) {
413
    #pragma omp parallel for schedule(static)
414
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
415
      double tmp = 0.0f;
416
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
417
      init_score_[i] = Common::AvoidInf(static_cast<double>(tmp));
418
    }
419
  } else {
420
    std::vector<std::string> oneline_init_score;
421
    #pragma omp parallel for schedule(static)
422
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
423
      double tmp = 0.0f;
424
425
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
426
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
427
428
429
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
430
        init_score_[static_cast<size_t>(k) * num_line + i] = Common::AvoidInf(static_cast<double>(tmp));
431
      }
432
    }
Guolin Ke's avatar
Guolin Ke committed
433
  }
434
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
435
436
437
438
439
440
441
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
442
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
443
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
444
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
445
446
    return;
  }
447
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
448
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
449
450
451
452
453
454
455
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
456
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
457
458
459
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
460
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
461
462
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
463
  query_weights_.clear();
464
  Log::Info("Loading query weights...");
465
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
466
467
468
469
470
471
472
473
474
475
476
477
478
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
479
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_data_));
Guolin Ke's avatar
Guolin Ke committed
480
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
481
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_weights_));
Guolin Ke's avatar
Guolin Ke committed
482
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
483
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_queries_));
Guolin Ke's avatar
Guolin Ke committed
484

Guolin Ke's avatar
Guolin Ke committed
485
  if (!label_.empty()) { label_.clear(); }
486
  label_ = std::vector<label_t>(num_data_);
487
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t) * num_data_);
488
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
489
490

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
491
    if (!weights_.empty()) { weights_.clear(); }
492
    weights_ = std::vector<label_t>(num_weights_);
493
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t) * num_weights_);
494
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_weights_);
495
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
496
497
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
498
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
499
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
500
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t) * (num_queries_ + 1));
501
502
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(data_size_t) *
                                              (num_queries_ + 1));
503
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
504
  }
Guolin Ke's avatar
Guolin Ke committed
505
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
506
507
}

508
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
509
510
511
512
  writer->AlignedWrite(&num_data_, sizeof(num_data_));
  writer->AlignedWrite(&num_weights_, sizeof(num_weights_));
  writer->AlignedWrite(&num_queries_, sizeof(num_queries_));
  writer->AlignedWrite(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
513
  if (!weights_.empty()) {
514
    writer->AlignedWrite(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
515
  }
Guolin Ke's avatar
Guolin Ke committed
516
  if (!query_boundaries_.empty()) {
517
518
    writer->AlignedWrite(query_boundaries_.data(),
                         sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
519
  }
520
521
522
523
  if (num_init_score_ > 0) {
    Log::Warning("Please note that `init_score` is not saved in binary file.\n"
      "If you need it, please set it again after loading Dataset.");
  }
Guolin Ke's avatar
Guolin Ke committed
524
525
}

526
size_t Metadata::SizesInByte() const {
527
528
529
530
  size_t size = VirtualFileWriter::AlignedSize(sizeof(num_data_)) +
                VirtualFileWriter::AlignedSize(sizeof(num_weights_)) +
                VirtualFileWriter::AlignedSize(sizeof(num_queries_));
  size += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
531
  if (!weights_.empty()) {
532
    size += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
533
  }
Guolin Ke's avatar
Guolin Ke committed
534
  if (!query_boundaries_.empty()) {
535
536
    size += VirtualFileWriter::AlignedSize(sizeof(data_size_t) *
                                           (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
537
538
539
540
541
542
  }
  return size;
}


}  // namespace LightGBM