// metadata.cpp: member functions of LightGBM::Metadata (labels, weights,
// query boundaries and initial scores attached to a data set).
#include <LightGBM/dataset.h>

#include <LightGBM/utils/common.h>

#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

namespace LightGBM {

// Creates an empty Metadata object.
//
// The row/weight/query/initial-score counters are zeroed here because other
// members read them without checking whether a loader ever set them (for
// example SaveBinaryToFile writes num_weights_ and num_queries_
// unconditionally, and Init(num_data, ...) only assigns num_weights_ when a
// weight column is present). num_class_ is assigned by both Init overloads.
Metadata::Metadata() {
  num_data_ = 0;
  num_weights_ = 0;
  num_queries_ = 0;
  num_init_score_ = 0;
}

void Metadata::Init(const char * data_filename, const int num_class) {
Guolin Ke's avatar
Guolin Ke committed
14
  data_filename_ = data_filename;
15
  num_class_ = num_class;
Guolin Ke's avatar
Guolin Ke committed
16
17
18
19
20
21
22
23
24
25
26
27
28
  // for lambdarank, it needs query data for partition data in parallel learning
  LoadQueryBoundaries();
  LoadWeights();
  LoadQueryWeights();
  LoadInitialScore();
}



// No explicit cleanup is performed here.
Metadata::~Metadata() {}


// Allocates per-row storage for metadata that is read from columns of the
// data file itself.
//   num_data   - number of rows to allocate for
//   num_class  - number of classes
//   weight_idx - negative: no weight column; otherwise a zeroed per-row
//                weight array is allocated
//   query_idx  - negative: no query-id column; otherwise a zeroed per-row
//                query-id array is allocated
void Metadata::Init(data_size_t num_data, int num_class, int weight_idx, int query_idx) {
  num_class_ = num_class;
  num_data_ = num_data;
  label_.assign(static_cast<size_t>(num_data_), 0.0f);

  const bool has_weight_column = (weight_idx >= 0);
  if (has_weight_column) {
    // Weights from the data file take precedence over a previously loaded
    // weight file.
    if (!weights_.empty()) {
      Log::Info("Using weights in data file, ignoring the additional weights file");
    }
    num_weights_ = num_data_;
    weights_.assign(static_cast<size_t>(num_data_), 0.0f);
  }

  const bool has_query_column = (query_idx >= 0);
  if (has_query_column) {
    // Query ids from the data file take precedence over a previously loaded
    // query file.
    if (!query_boundaries_.empty()) {
      Log::Info("Using query id in data file, ignoring the additional query file");
      query_boundaries_.clear();
    }
    query_weights_.clear();
    queries_.assign(static_cast<size_t>(num_data_), 0);
  }
}

void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
  if (used_indices.size() <= 0) {
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
57
  auto old_label = label_;
Guolin Ke's avatar
Guolin Ke committed
58
  num_data_ = static_cast<data_size_t>(used_indices.size());
Guolin Ke's avatar
Guolin Ke committed
59
  label_ = std::vector<float>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
60
61
62
  for (data_size_t i = 0; i < num_data_; ++i) {
    label_[i] = old_label[used_indices[i]];
  }
Guolin Ke's avatar
Guolin Ke committed
63
  old_label.clear();
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
}

void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
68
    if (queries_.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
      // need convert query_id to boundaries
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      tmp_buffer.push_back(cur_cnt);
Guolin Ke's avatar
Guolin Ke committed
84
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
91
      queries_.clear();
Guolin Ke's avatar
Guolin Ke committed
92
    }
Guolin Ke's avatar
Guolin Ke committed
93
    // check weights
Guolin Ke's avatar
Guolin Ke committed
94
95
    if (weights_.size() > 0 && num_weights_ != num_data_) {
      weights_.clear();
Guolin Ke's avatar
Guolin Ke committed
96
      num_weights_ = 0;
Guolin Ke's avatar
Guolin Ke committed
97
      Log::Fatal("Weights size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
98
99
100
    }

    // check query boundries
Guolin Ke's avatar
Guolin Ke committed
101
102
    if (query_boundaries_.size() > 0 && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
Guolin Ke's avatar
Guolin Ke committed
103
      num_queries_ = 0;
Guolin Ke's avatar
Guolin Ke committed
104
      Log::Fatal("Query size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
105
106
107
    }

    // contain initial score file
Guolin Ke's avatar
Guolin Ke committed
108
109
    if (init_score_.size() > 0 && num_init_score_ != num_data_) {
      init_score_.clear();
Guolin Ke's avatar
Guolin Ke committed
110
      num_init_score_ = 0;
Guolin Ke's avatar
Guolin Ke committed
111
      Log::Fatal("Initial score size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
112
113
114
115
    }
  } else {
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // check weights
Guolin Ke's avatar
Guolin Ke committed
116
117
    if (weights_.size() > 0 && num_weights_ != num_all_data) {
      weights_.clear();
Guolin Ke's avatar
Guolin Ke committed
118
      num_weights_ = 0;
Guolin Ke's avatar
Guolin Ke committed
119
      Log::Fatal("Weights size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
120
121
    }
    // check query boundries
Guolin Ke's avatar
Guolin Ke committed
122
123
    if (query_boundaries_.size() > 0 && query_boundaries_[num_queries_] != num_all_data) {
      query_boundaries_.clear();
Guolin Ke's avatar
Guolin Ke committed
124
      num_queries_ = 0;
Guolin Ke's avatar
Guolin Ke committed
125
      Log::Fatal("Query size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
126
127
128
    }

    // contain initial score file
Guolin Ke's avatar
Guolin Ke committed
129
130
    if (init_score_.size() > 0 && num_init_score_ != num_all_data) {
      init_score_.clear();
Guolin Ke's avatar
Guolin Ke committed
131
      num_init_score_ = 0;
Guolin Ke's avatar
Guolin Ke committed
132
      Log::Fatal("Initial score size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
133
134
135
    }

    // get local weights
Guolin Ke's avatar
Guolin Ke committed
136
137
    if (weights_.size() > 0) {
      auto old_weights = weights_;
Guolin Ke's avatar
Guolin Ke committed
138
      num_weights_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
139
      weights_ = std::vector<float>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
140
141
142
      for (size_t i = 0; i < used_data_indices.size(); ++i) {
        weights_[i] = old_weights[used_data_indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
143
      old_weights.clear();
Guolin Ke's avatar
Guolin Ke committed
144
145
146
    }

    // get local query boundaries
Guolin Ke's avatar
Guolin Ke committed
147
    if (query_boundaries_.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
152
153
154
155
156
157
158
159
160
      std::vector<data_size_t> used_query;
      data_size_t data_idx = 0;
      for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
        data_size_t start = query_boundaries_[qid];
        data_size_t end = query_boundaries_[qid + 1];
        data_size_t len = end - start;
        if (used_data_indices[data_idx] > start) {
          continue;
        } else if (used_data_indices[data_idx] == start) {
          if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
            used_query.push_back(qid);
            data_idx += len;
          } else {
Guolin Ke's avatar
Guolin Ke committed
161
            Log::Fatal("Data partition error, data didn't match queries");
Guolin Ke's avatar
Guolin Ke committed
162
163
          }
        } else {
Guolin Ke's avatar
Guolin Ke committed
164
          Log::Fatal("Data partition error, data didn't match queries");
Guolin Ke's avatar
Guolin Ke committed
165
166
        }
      }
Guolin Ke's avatar
Guolin Ke committed
167
168
      auto old_query_boundaries = query_boundaries_;
      query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
Guolin Ke's avatar
Guolin Ke committed
169
170
171
172
173
174
175
      num_queries_ = static_cast<data_size_t>(used_query.size());
      query_boundaries_[0] = 0;
      for (data_size_t i = 0; i < num_queries_; ++i) {
        data_size_t qid = used_query[i];
        data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
        query_boundaries_[i + 1] = query_boundaries_[i] + len;
      }
Guolin Ke's avatar
Guolin Ke committed
176
      old_query_boundaries.clear();
Guolin Ke's avatar
Guolin Ke committed
177
178
179
    }

    // get local initial scores
Guolin Ke's avatar
Guolin Ke committed
180
181
    if (init_score_.size() > 0) {
      auto old_scores = init_score_;
Guolin Ke's avatar
Guolin Ke committed
182
      num_init_score_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
183
      init_score_ = std::vector<float>(num_init_score_ * num_class_);
184
185
186
      for (int k = 0; k < num_class_; ++k){
        for (size_t i = 0; i < used_data_indices.size(); ++i) {
          init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
187
        }
Guolin Ke's avatar
Guolin Ke committed
188
      }
Guolin Ke's avatar
Guolin Ke committed
189
      old_scores.clear();
Guolin Ke's avatar
Guolin Ke committed
190
191
192
193
194
195
196
197
    }

    // re-load query weight
    LoadQueryWeights();
  }
}


void Metadata::SetInitScore(const float* init_score, data_size_t len) {
199
  if (len != num_data_ * num_class_) {
200
    Log::Fatal("Initial score size doesn't match data size");
Guolin Ke's avatar
Guolin Ke committed
201
  }
Guolin Ke's avatar
Guolin Ke committed
202
  if (init_score_.size() > 0) { init_score_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
203
  num_init_score_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
204
  init_score_ = std::vector<float>(len);
205
  for (data_size_t i = 0; i < len; ++i) {
Guolin Ke's avatar
Guolin Ke committed
206
207
    init_score_[i] = init_score[i];
  }
Guolin Ke's avatar
Guolin Ke committed
208
209
}

// Replaces the stored labels with a copy of the caller's array.
//   label - source values
//   len   - number of values; must equal num_data_
void Metadata::SetLabel(const float* label, data_size_t len) {
  if (len != num_data_) {
    Log::Fatal("len of label is not same with #data");
  }
  label_.clear();
  label_.resize(num_data_);
  const float* src = label;
  for (float& dst : label_) {
    dst = *src++;
  }
}

void Metadata::SetWeights(const float* weights, data_size_t len) {
  if (num_data_ != len) {
    Log::Fatal("len of weights is not same with #data");
  }
Guolin Ke's avatar
Guolin Ke committed
225
  if (weights_.size() > 0) { weights_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
226
  num_weights_ = num_data_;
Guolin Ke's avatar
Guolin Ke committed
227
  weights_ = std::vector<float>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
228
229
230
231
232
233
234
235
236
237
238
239
240
241
  for (data_size_t i = 0; i < num_weights_; ++i) {
    weights_[i] = weights[i];
  }
  LoadQueryWeights();
}

void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len) {
  data_size_t sum = 0;
  for (data_size_t i = 0; i < len; ++i) {
    sum += query_boundaries[i];
  }
  if (num_data_ != sum) {
    Log::Fatal("sum of query counts is not same with #data");
  }
Guolin Ke's avatar
Guolin Ke committed
242
  if (query_boundaries_.size() > 0) { query_boundaries_.clear(); }
Guolin Ke's avatar
Guolin Ke committed
243
  num_queries_ = len;
Guolin Ke's avatar
Guolin Ke committed
244
  query_boundaries_ = std::vector<data_size_t>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
245
246
247
248
249
250
251
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_boundaries_[i] = query_boundaries[i];
  }
  LoadQueryWeights();
}


void Metadata::LoadWeights() {
  num_weights_ = 0;
  std::string weight_filename(data_filename_);
  // default weight file name
  weight_filename.append(".weight");
Guolin Ke's avatar
Guolin Ke committed
257
  TextReader<size_t> reader(weight_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
258
259
260
261
  reader.ReadAllLines();
  if (reader.Lines().size() <= 0) {
    return;
  }
262
  Log::Info("Loading weights...");
Guolin Ke's avatar
Guolin Ke committed
263
  num_weights_ = static_cast<data_size_t>(reader.Lines().size());
Guolin Ke's avatar
Guolin Ke committed
264
  weights_ = std::vector<float>(num_weights_);
Guolin Ke's avatar
Guolin Ke committed
265
  for (data_size_t i = 0; i < num_weights_; ++i) {
266
    double tmp_weight = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
267
    Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
268
    weights_[i] = static_cast<float>(tmp_weight);
Guolin Ke's avatar
Guolin Ke committed
269
270
271
272
273
  }
}

void Metadata::LoadInitialScore() {
  num_init_score_ = 0;
Guolin Ke's avatar
Guolin Ke committed
274
275
276
277
  std::string init_score_filename(data_filename_);
  // default weight file name
  init_score_filename.append(".init");
  TextReader<size_t> reader(init_score_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
278
  reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
279
280
281
  if (reader.Lines().size() <= 0) {
    return;
  }
282
  Log::Info("Loading initial scores...");
Guolin Ke's avatar
Guolin Ke committed
283
  num_init_score_ = static_cast<data_size_t>(reader.Lines().size());
284

Guolin Ke's avatar
Guolin Ke committed
285
  init_score_ = std::vector<float>(num_init_score_ * num_class_);
286
  double tmp = 0.0f;
287

288
289
290
291
292
293
294
295
  if (num_class_ == 1){
      for (data_size_t i = 0; i < num_init_score_; ++i) {
        Common::Atof(reader.Lines()[i].c_str(), &tmp);
        init_score_[i] = static_cast<float>(tmp);
      }
  } else {
      std::vector<std::string> oneline_init_score;
      for (data_size_t i = 0; i < num_init_score_; ++i) {
296
        oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
297
298
299
300
301
302
303
304
        if (static_cast<int>(oneline_init_score.size()) != num_class_){
            Log::Fatal("Invalid initial score file. Redundant or insufficient columns.");
        }
        for (int k = 0; k < num_class_; ++k) {
          Common::Atof(oneline_init_score[k].c_str(), &tmp);
          init_score_[k * num_init_score_ + i] = static_cast<float>(tmp);
        }
      }
Guolin Ke's avatar
Guolin Ke committed
305
306
307
308
309
310
311
312
  }
}

void Metadata::LoadQueryBoundaries() {
  num_queries_ = 0;
  std::string query_filename(data_filename_);
  // default query file name
  query_filename.append(".query");
Guolin Ke's avatar
Guolin Ke committed
313
  TextReader<size_t> reader(query_filename.c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
  reader.ReadAllLines();
  if (reader.Lines().size() <= 0) {
    return;
  }
318
  Log::Info("Loading query boundaries...");
Guolin Ke's avatar
Guolin Ke committed
319
  query_boundaries_ = std::vector<data_size_t>(reader.Lines().size() + 1);
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
325
326
327
328
329
  num_queries_ = static_cast<data_size_t>(reader.Lines().size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < reader.Lines().size(); ++i) {
    int tmp_cnt;
    Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
    query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
  }
}

void Metadata::LoadQueryWeights() {
Guolin Ke's avatar
Guolin Ke committed
330
  if (weights_.size() == 0 || query_boundaries_.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
331
332
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
333
  query_weights_.clear();
334
  Log::Info("Loading query weights...");
Guolin Ke's avatar
Guolin Ke committed
335
  query_weights_ = std::vector<float>(num_queries_);
Guolin Ke's avatar
Guolin Ke committed
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_weights_[i] = 0.0f;
    for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
      query_weights_[i] += weights_[j];
    }
    query_weights_[i] /= (query_boundaries_[i + 1] - query_boundaries_[i]);
  }
}

void Metadata::LoadFromMemory(const void* memory) {
  const char* mem_ptr = reinterpret_cast<const char*>(memory);

  num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_data_);
  num_weights_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_weights_);
  num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(num_queries_);

Guolin Ke's avatar
Guolin Ke committed
355
356
357
  if (label_.size() > 0) { label_.clear(); }
  label_ = std::vector<float>(num_data_);
  std::memcpy(label_.data(), mem_ptr, sizeof(float)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
358
  mem_ptr += sizeof(float)*num_data_;
Guolin Ke's avatar
Guolin Ke committed
359
360

  if (num_weights_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
361
362
363
    if (weights_.size() > 0) { weights_.clear(); }
    weights_ = std::vector<float>(num_weights_);
    std::memcpy(weights_.data(), mem_ptr, sizeof(float)*num_weights_);
Guolin Ke's avatar
Guolin Ke committed
364
365
366
    mem_ptr += sizeof(float)*num_weights_;
  }
  if (num_queries_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
367
368
369
    if (query_boundaries_.size() > 0) { query_boundaries_.clear(); }
    query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
    std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1));
Guolin Ke's avatar
Guolin Ke committed
370
371
    mem_ptr += sizeof(data_size_t)*(num_queries_ + 1);
  }
Guolin Ke's avatar
Guolin Ke committed
372
  LoadQueryWeights();
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
377
378
}

void Metadata::SaveBinaryToFile(FILE* file) const {
  fwrite(&num_data_, sizeof(num_data_), 1, file);
  fwrite(&num_weights_, sizeof(num_weights_), 1, file);
  fwrite(&num_queries_, sizeof(num_queries_), 1, file);
Guolin Ke's avatar
Guolin Ke committed
379
380
381
  fwrite(label_.data(), sizeof(float), num_data_, file);
  if (weights_.size() > 0) {
    fwrite(weights_.data(), sizeof(float), num_weights_, file);
Guolin Ke's avatar
Guolin Ke committed
382
  }
Guolin Ke's avatar
Guolin Ke committed
383
384
  if (query_boundaries_.size() > 0) {
    fwrite(query_boundaries_.data(), sizeof(data_size_t), num_queries_ + 1, file);
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
389
390
391
392
  }

}

size_t Metadata::SizesInByte() const  {
  size_t size = sizeof(num_data_) + sizeof(num_weights_)
    + sizeof(num_queries_);
  size += sizeof(float) * num_data_;
Guolin Ke's avatar
Guolin Ke committed
393
  if (weights_.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
394
395
    size += sizeof(float) * num_weights_;
  }
Guolin Ke's avatar
Guolin Ke committed
396
  if (query_boundaries_.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
397
398
399
400
401
402
403
    size += sizeof(data_size_t) * (num_queries_ + 1);
  }
  return size;
}


}  // namespace LightGBM