"vscode:/vscode.git/clone" did not exist on "4f8c32d9a6c6b0c8d774d571da3bcd70191d218b"
metadata.cpp 18.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
7
8

#include <string>
9
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
10
11
12

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
13
/*! \brief Create empty metadata: no rows, no optional fields, nothing loaded from side files. */
Metadata::Metadata() {
  // All element counts start at zero.
  num_data_ = 0;
  num_weights_ = 0;
  num_queries_ = 0;
  num_init_score_ = 0;
  // None of the optional fields has come from a side file yet.
  weight_load_from_file_ = false;
  query_load_from_file_ = false;
  init_score_load_from_file_ = false;
}

23
void Metadata::Init(const char* data_filename) {
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
  data_filename_ = data_filename;
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
29
  LoadInitialScore();
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
}

// No explicit cleanup is performed; members are released by their own destructors.
Metadata::~Metadata() = default;

35
/*!
 * \brief Size the metadata for \p num_data rows whose weights / query ids may come from
 *        columns of the data file itself.
 * \param num_data Number of rows.
 * \param weight_idx Weight column index, or a negative value when there is none.
 * \param query_idx Query-id column index, or a negative value when there is none.
 * A column in the data file overrides a previously loaded side file of the same kind.
 */
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
  num_data_ = num_data;
  label_.assign(num_data_, label_t{});

  const bool has_weight_column = weight_idx >= 0;
  const bool has_query_column = query_idx >= 0;

  if (has_weight_column) {
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
    }
    // Zero-filled, one entry per row.
    weights_.assign(num_data_, 0.0f);
    num_weights_ = num_data_;
    weight_load_from_file_ = false;
  }

  if (has_query_column) {
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    query_weights_.clear();
    // One query id per row, zero-filled here; CheckOrPartition later converts these ids
    // into query boundaries.
    queries_.assign(num_data_, 0);
    query_load_from_file_ = false;
  }
}

Guolin Ke's avatar
Guolin Ke committed
58
59
60
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
  num_data_ = num_used_indices;

61
  label_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
62
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
63
  for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
64
65
66
    label_[i] = fullset.label_[used_indices[i]];
  }

Guolin Ke's avatar
Guolin Ke committed
67
  if (!fullset.weights_.empty()) {
68
    weights_ = std::vector<label_t>(num_used_indices);
Guolin Ke's avatar
Guolin Ke committed
69
    num_weights_ = num_used_indices;
Guolin Ke's avatar
Guolin Ke committed
70
#pragma omp parallel for schedule(static, 512) if (num_used_indices >= 1024)
71
    for (data_size_t i = 0; i < num_used_indices; ++i) {
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
      weights_[i] = fullset.weights_[used_indices[i]];
    }
  } else {
    num_weights_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
78
  if (!fullset.init_score_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
79
    int num_class = static_cast<int>(fullset.num_init_score_ / fullset.num_data_);
80
    init_score_ = std::vector<double>(static_cast<size_t>(num_used_indices) * num_class);
Guolin Ke's avatar
Guolin Ke committed
81
    num_init_score_ = static_cast<int64_t>(num_used_indices) * num_class;
82
    #pragma omp parallel for schedule(static)
83
    for (int k = 0; k < num_class; ++k) {
84
85
86
87
      const size_t offset_dest = static_cast<size_t>(k) * num_data_;
      const size_t offset_src = static_cast<size_t>(k) * fullset.num_data_;
      for (data_size_t i = 0; i < num_used_indices; ++i) {
        init_score_[offset_dest + i] = fullset.init_score_[offset_src + used_indices[i]];
88
      }
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
93
    }
  } else {
    num_init_score_ = 0;
  }

Guolin Ke's avatar
Guolin Ke committed
94
  if (!fullset.query_boundaries_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    std::vector<data_size_t> used_query;
    data_size_t data_idx = 0;
    for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) {
      data_size_t start = fullset.query_boundaries_[qid];
      data_size_t end = fullset.query_boundaries_[qid + 1];
      data_size_t len = end - start;
      if (used_indices[data_idx] > start) {
        continue;
      } else if (used_indices[data_idx] == start) {
        if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) {
          used_query.push_back(qid);
          data_idx += len;
        } else {
          Log::Fatal("Data partition error, data didn't match queries");
        }
      } else {
        Log::Fatal("Data partition error, data didn't match queries");
      }
    }
    query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
    num_queries_ = static_cast<data_size_t>(used_query.size());
    query_boundaries_[0] = 0;
    for (data_size_t i = 0; i < num_queries_; ++i) {
      data_size_t qid = used_query[i];
      data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid];
      query_boundaries_[i + 1] = query_boundaries_[i] + len;
    }
  } else {
    num_queries_ = 0;
  }
}

Guolin Ke's avatar
Guolin Ke committed
127
void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
Guolin Ke's avatar
Guolin Ke committed
128
  if (used_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
129
130
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
131
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
132
  num_data_ = static_cast<data_size_t>(used_indices.size());
133
  label_ = std::vector<label_t>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
134
#pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
135
136
137
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
138
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
139
140
141
}

/*!
 * \brief Validate metadata sizes against the data, or cut the metadata down to the locally
 *        used rows.
 * \param num_all_data Number of rows in the complete data file.
 * \param used_data_indices Rows of the complete data that are kept locally. When empty, nothing
 *        is partitioned and each field is only checked against num_data_. Otherwise the
 *        fields flagged as loaded from file are checked against num_all_data and then reduced
 *        to these rows.
 * Size mismatches are reported through Log::Fatal.
 */
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // Convert per-row query ids into query boundaries: every run of consecutive rows
      // sharing the same id becomes one query.
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      tmp_buffer.push_back(cur_cnt);
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
      queries_.clear();
    }
    // check weights
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // check query boundaries
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // check initial scores (one block of num_data_ values per class). num_data_ == 0 is
    // tested first because integer modulo by zero is undefined behavior.
    if (!init_score_.empty() && (num_data_ == 0 || (num_init_score_ % num_data_) != 0)) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for parallel training");
    }
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // Only fields flagged as loaded from file are partitioned; they must cover all
    // num_all_data rows. NOTE(review): the local buffers below are sized by num_data_, which
    // is assumed to already equal used_data_indices.size() (e.g. after PartitionLabel).
    // check weights
    if (weight_load_from_file_) {
      if (!weights_.empty() && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // get local weights
      if (!weights_.empty()) {
        auto old_weights = weights_;
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static, 512)
        for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
          weights_[i] = old_weights[used_data_indices[i]];
        }
        old_weights.clear();
      }
    }
    if (query_load_from_file_) {
      // check query boundaries
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // get local query boundaries: each query must be either fully used or fully skipped
      if (!query_boundaries_.empty()) {
        std::vector<data_size_t> used_query;
        data_size_t data_idx = 0;
        for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
          data_size_t start = query_boundaries_[qid];
          data_size_t end = query_boundaries_[qid + 1];
          data_size_t len = end - start;
          if (used_data_indices[data_idx] > start) {
            continue;
          } else if (used_data_indices[data_idx] == start) {
            if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
              used_query.push_back(qid);
              data_idx += len;
            } else {
              Log::Fatal("Data partition error, data didn't match queries");
            }
          } else {
            Log::Fatal("Data partition error, data didn't match queries");
          }
        }
        auto old_query_boundaries = query_boundaries_;
        query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
        num_queries_ = static_cast<data_size_t>(used_query.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          data_size_t qid = used_query[i];
          data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
          query_boundaries_[i + 1] = query_boundaries_[i] + len;
        }
        old_query_boundaries.clear();
      }
    }
    if (init_score_load_from_file_) {
      // check initial scores; num_all_data == 0 is tested first to avoid modulo by zero
      if (!init_score_.empty() && (num_all_data == 0 || (num_init_score_ % num_all_data) != 0)) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }

      // get local initial scores (one block of rows per class)
      if (!init_score_.empty()) {
        auto old_scores = init_score_;
        int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          const size_t offset_dest = static_cast<size_t>(k) * num_data_;
          const size_t offset_src = static_cast<size_t>(k) * num_all_data;
          for (size_t i = 0; i < used_data_indices.size(); ++i) {
            init_score_[offset_dest + i] = old_scores[offset_src + used_data_indices[i]];
          }
        }
        old_scores.clear();
      }
    }
    // re-load query weight
    LoadQueryWeights();
  }
}

282
283
284
285
286
287
288
289
290
291
292
void Metadata::SetInitScore(const double* init_score, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // save to nullptr
  if (init_score == nullptr || len == 0) {
    init_score_.clear();
    num_init_score_ = 0;
    return;
  }
  if ((len % num_data_) != 0) {
    Log::Fatal("Initial score size doesn't match data size");
  }
293
  if (init_score_.empty()) { init_score_.resize(len); }
294
  num_init_score_ = len;
295

Guolin Ke's avatar
Guolin Ke committed
296
  #pragma omp parallel for schedule(static, 512) if (num_init_score_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
297
  for (int64_t i = 0; i < num_init_score_; ++i) {
298
    init_score_[i] = Common::AvoidInf(init_score[i]);
299
  }
300
  init_score_load_from_file_ = false;
301
302
}

303
void Metadata::SetLabel(const label_t* label, data_size_t len) {
304
  std::lock_guard<std::mutex> lock(mutex_);
305
306
307
  if (label == nullptr) {
    Log::Fatal("label cannot be nullptr");
  }
Guolin Ke's avatar
Guolin Ke committed
308
  if (num_data_ != len) {
309
    Log::Fatal("Length of label is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
310
  }
311
  if (label_.empty()) { label_.resize(num_data_); }
312

Guolin Ke's avatar
Guolin Ke committed
313
  #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
314
  for (data_size_t i = 0; i < num_data_; ++i) {
315
    label_[i] = Common::AvoidInf(label[i]);
Guolin Ke's avatar
Guolin Ke committed
316
317
318
  }
}

319
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
320
  std::lock_guard<std::mutex> lock(mutex_);
321
322
323
324
325
326
  // save to nullptr
  if (weights == nullptr || len == 0) {
    weights_.clear();
    num_weights_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
327
  if (num_data_ != len) {
328
    Log::Fatal("Length of weights is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
329
  }
330
  if (weights_.empty()) { weights_.resize(num_data_); }
Guolin Ke's avatar
Guolin Ke committed
331
  num_weights_ = num_data_;
332

Guolin Ke's avatar
Guolin Ke committed
333
  #pragma omp parallel for schedule(static, 512) if (num_weights_ >= 1024)
Guolin Ke's avatar
Guolin Ke committed
334
  for (data_size_t i = 0; i < num_weights_; ++i) {
335
    weights_[i] = Common::AvoidInf(weights[i]);
Guolin Ke's avatar
Guolin Ke committed
336
337
  }
  LoadQueryWeights();
338
  weight_load_from_file_ = false;
Guolin Ke's avatar
Guolin Ke committed
339
340
}

Guolin Ke's avatar
Guolin Ke committed
341
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
342
  std::lock_guard<std::mutex> lock(mutex_);
343
  // save to nullptr
Guolin Ke's avatar
Guolin Ke committed
344
  if (query == nullptr || len == 0) {
345
346
347
348
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
349
  data_size_t sum = 0;
350
  #pragma omp parallel for schedule(static) reduction(+:sum)
Guolin Ke's avatar
Guolin Ke committed
351
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
352
    sum += query[i];
Guolin Ke's avatar
Guolin Ke committed
353
354
  }
  if (num_data_ != sum) {
355
    Log::Fatal("Sum of query counts is not same with #data");
Guolin Ke's avatar
Guolin Ke committed
356
357
  }
  num_queries_ = len;
358
  query_boundaries_.resize(num_queries_ + 1);
Guolin Ke's avatar
Guolin Ke committed
359
  query_boundaries_[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
360
  for (data_size_t i = 0; i < num_queries_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
361
    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Guolin Ke's avatar
Guolin Ke committed
362
363
  }
  LoadQueryWeights();
364
  query_load_from_file_ = false;
365
}
Guolin Ke's avatar
Guolin Ke committed
366

Guolin Ke's avatar
Guolin Ke committed
367
368
369
370
371
void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
372
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
373
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
374
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
375
376
    return;
  }
377
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
378
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
379
  weights_ = std::vector<label_t>(num_weights_);
380
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
381
  for (data_size_t i = 0; i < num_weights_; ++i) {
382
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
383
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
384
    weights_[i] = Common::AvoidInf(static_cast<label_t>(tmp_weight));
Guolin Ke's avatar
Guolin Ke committed
385
  }
386
  weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
387
388
}

389
void Metadata::LoadInitialScore() {
Guolin Ke's avatar
Guolin Ke committed
390
  num_init_score_ = 0;
391
392
393
394
  std::string init_score_filename(data_filename_);
  init_score_filename = std::string(data_filename_);
  // default init_score file name
  init_score_filename.append(".init");
Guolin Ke's avatar
Guolin Ke committed
395
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
396
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
397
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
398
399
    return;
  }
400
401
  Log::Info("Loading initial scores...");

402
403
404
  // use first line to count number class
  int num_class = static_cast<int>(Common::Split(reader.Lines()[0].c_str(), '\t').size());
  data_size_t num_line = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
405
  num_init_score_ = static_cast<int64_t>(num_line) * num_class;
406

Guolin Ke's avatar
Guolin Ke committed
407
  init_score_ = std::vector<double>(num_init_score_);
408
  if (num_class == 1) {
409
    #pragma omp parallel for schedule(static)
410
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
411
      double tmp = 0.0f;
412
      Common::Atof(reader.Lines()[i].c_str(), &tmp);
413
      init_score_[i] = Common::AvoidInf(static_cast<double>(tmp));
414
    }
415
  } else {
416
    std::vector<std::string> oneline_init_score;
417
    #pragma omp parallel for schedule(static)
418
    for (data_size_t i = 0; i < num_line; ++i) {
Guolin Ke's avatar
Guolin Ke committed
419
      double tmp = 0.0f;
420
421
      oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
      if (static_cast<int>(oneline_init_score.size()) != num_class) {
422
        Log::Fatal("Invalid initial score file. Redundant or insufficient columns");
423
424
425
      }
      for (int k = 0; k < num_class; ++k) {
        Common::Atof(oneline_init_score[k].c_str(), &tmp);
426
        init_score_[static_cast<size_t>(k) * num_line + i] = Common::AvoidInf(static_cast<double>(tmp));
427
      }
428
    }
Guolin Ke's avatar
Guolin Ke committed
429
  }
430
  init_score_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
431
432
433
434
435
436
437
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
438
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
439
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
440
  if (reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
441
442
    return;
  }
443
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
444
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
445
446
447
448
449
450
451
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
452
  query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
453
454
455
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
456
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
457
458
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
459
  query_weights_.clear();
460
  Log::Info("Loading query weights...");
461
  query_weights_ = std::vector<label_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
462
463
464
465
466
467
468
469
470
471
472
473
474
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
475
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_data_));
Guolin Ke's avatar
Guolin Ke committed
476
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
477
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_weights_));
Guolin Ke's avatar
Guolin Ke committed
478
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
479
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(num_queries_));
Guolin Ke's avatar
Guolin Ke committed
480

Guolin Ke's avatar
Guolin Ke committed
481
  if (!label_.empty()) { label_.clear(); }
482
  label_ = std::vector<label_t>(num_data_);
483
  std::memcpy(label_.data(), mem_ptr, sizeof(label_t) * num_data_);
484
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
485
486

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
487
    if (!weights_.empty()) { weights_.clear(); }
488
    weights_ = std::vector<label_t>(num_weights_);
489
    std::memcpy(weights_.data(), mem_ptr, sizeof(label_t) * num_weights_);
490
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_weights_);
491
    weight_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
492
493
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
494
    if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
495
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
496
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t) * (num_queries_ + 1));
497
498
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(data_size_t) *
                                              (num_queries_ + 1));
499
    query_load_from_file_ = true;
Guolin Ke's avatar
Guolin Ke committed
500
  }
Guolin Ke's avatar
Guolin Ke committed
501
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
502
503
}

504
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
505
506
507
508
  writer->AlignedWrite(&num_data_, sizeof(num_data_));
  writer->AlignedWrite(&num_weights_, sizeof(num_weights_));
  writer->AlignedWrite(&num_queries_, sizeof(num_queries_));
  writer->AlignedWrite(label_.data(), sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
509
  if (!weights_.empty()) {
510
    writer->AlignedWrite(weights_.data(), sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
511
  }
Guolin Ke's avatar
Guolin Ke committed
512
  if (!query_boundaries_.empty()) {
513
514
    writer->AlignedWrite(query_boundaries_.data(),
                         sizeof(data_size_t) * (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
515
  }
516
517
518
519
  if (num_init_score_ > 0) {
    Log::Warning("Please note that `init_score` is not saved in binary file.\n"
      "If you need it, please set it again after loading Dataset.");
  }
Guolin Ke's avatar
Guolin Ke committed
520
521
}

522
size_t Metadata::SizesInByte() const {
523
524
525
526
  size_t size = VirtualFileWriter::AlignedSize(sizeof(num_data_)) +
                VirtualFileWriter::AlignedSize(sizeof(num_weights_)) +
                VirtualFileWriter::AlignedSize(sizeof(num_queries_));
  size += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_data_);
Guolin Ke's avatar
Guolin Ke committed
527
  if (!weights_.empty()) {
528
    size += VirtualFileWriter::AlignedSize(sizeof(label_t) * num_weights_);
Guolin Ke's avatar
Guolin Ke committed
529
  }
Guolin Ke's avatar
Guolin Ke committed
530
  if (!query_boundaries_.empty()) {
531
532
    size += VirtualFileWriter::AlignedSize(sizeof(data_size_t) *
                                           (num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
533
534
535
536
537
538
  }
  return size;
}


}  // namespace LightGBM