Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
1765b2e3
Commit
1765b2e3
authored
Jan 25, 2017
by
Guolin Ke
Browse files
fix partition error when weight_column is set
parent
7150c722
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
99 additions
and
124 deletions
+99
-124
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+3
-2
include/LightGBM/dataset_loader.h
include/LightGBM/dataset_loader.h
+2
-2
src/io/dataset.cpp
src/io/dataset.cpp
+0
-2
src/io/dataset_loader.cpp
src/io/dataset_loader.cpp
+11
-11
src/io/metadata.cpp
src/io/metadata.cpp
+83
-107
No files found.
include/LightGBM/dataset.h
View file @
1765b2e3
...
...
@@ -88,8 +88,6 @@ public:
void
SetQuery
(
const
data_size_t
*
query
,
data_size_t
len
);
void
SetQueryId
(
const
data_size_t
*
query_id
,
data_size_t
len
);
/*!
* \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score.
...
...
@@ -244,6 +242,9 @@ private:
std
::
vector
<
data_size_t
>
queries_
;
/*! \brief mutex for threading safe call */
std
::
mutex
mutex_
;
bool
weight_load_from_file_
;
bool
query_load_from_file_
;
bool
init_score_load_from_file_
;
};
...
...
include/LightGBM/dataset_loader.h
View file @
1765b2e3
...
...
@@ -20,8 +20,6 @@ public:
LIGHTGBM_EXPORT
Dataset
*
LoadFromFileAlignWithOtherDataset
(
const
char
*
filename
,
const
Dataset
*
train_data
);
LIGHTGBM_EXPORT
Dataset
*
LoadFromBinFile
(
const
char
*
data_filename
,
const
char
*
bin_filename
,
int
rank
,
int
num_machines
);
LIGHTGBM_EXPORT
Dataset
*
CostructFromSampleData
(
std
::
vector
<
std
::
vector
<
double
>>&
sample_values
,
size_t
total_sample_size
,
data_size_t
num_data
);
/*! \brief Disable copy */
...
...
@@ -31,6 +29,8 @@ public:
private:
Dataset
*
LoadFromBinFile
(
const
char
*
data_filename
,
const
char
*
bin_filename
,
int
rank
,
int
num_machines
,
int
*
num_global_data
,
std
::
vector
<
data_size_t
>*
used_data_indices
);
void
SetHeader
(
const
char
*
filename
);
void
CheckDataset
(
const
Dataset
*
dataset
);
...
...
src/io/dataset.cpp
View file @
1765b2e3
...
...
@@ -112,8 +112,6 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
name
=
Common
::
Trim
(
name
);
if
(
name
==
std
::
string
(
"query"
)
||
name
==
std
::
string
(
"group"
))
{
metadata_
.
SetQuery
(
field_data
,
num_element
);
}
else
if
(
name
==
std
::
string
(
"query_id"
)
||
name
==
std
::
string
(
"group_id"
))
{
metadata_
.
SetQueryId
(
field_data
,
num_element
);
}
else
{
return
false
;
}
...
...
src/io/dataset_loader.cpp
View file @
1765b2e3
...
...
@@ -209,7 +209,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
}
}
else
{
// load data from binary file
dataset
.
reset
(
LoadFromBinFile
(
filename
,
bin_filename
.
c_str
(),
rank
,
num_machines
));
dataset
.
reset
(
LoadFromBinFile
(
filename
,
bin_filename
.
c_str
(),
rank
,
num_machines
,
&
num_global_data
,
&
used_data_indices
));
}
// check meta data
dataset
->
metadata_
.
CheckOrPartition
(
num_global_data
,
used_data_indices
);
...
...
@@ -255,7 +255,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
}
}
else
{
// load data from binary file
dataset
.
reset
(
LoadFromBinFile
(
filename
,
bin_filename
.
c_str
(),
0
,
1
));
dataset
.
reset
(
LoadFromBinFile
(
filename
,
bin_filename
.
c_str
(),
0
,
1
,
&
num_global_data
,
&
used_data_indices
));
}
// no need to check validation data
// check meta data
...
...
@@ -263,7 +263,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
return
dataset
.
release
();
}
Dataset
*
DatasetLoader
::
LoadFromBinFile
(
const
char
*
data_filename
,
const
char
*
bin_filename
,
int
rank
,
int
num_machines
)
{
Dataset
*
DatasetLoader
::
LoadFromBinFile
(
const
char
*
data_filename
,
const
char
*
bin_filename
,
int
rank
,
int
num_machines
,
int
*
num_global_data
,
std
::
vector
<
data_size_t
>*
used_data_indices
)
{
auto
dataset
=
std
::
unique_ptr
<
Dataset
>
(
new
Dataset
());
FILE
*
file
;
#ifdef _MSC_VER
...
...
@@ -364,8 +364,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// load meta data
dataset
->
metadata_
.
LoadFromMemory
(
buffer
.
data
());
std
::
vector
<
data_size_t
>
used_data_indices
;
data_size_t
num_global_data
=
dataset
->
num_data_
;
*
num_global_data
=
dataset
->
num_data_
;
used_data_indices
->
clear
()
;
// sample local used data if need to partition
if
(
num_machines
>
1
&&
!
io_config_
.
is_pre_partition
)
{
const
data_size_t
*
query_boundaries
=
dataset
->
metadata_
.
query_boundaries
();
...
...
@@ -373,7 +373,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// if not contain query file, minimal sample unit is one record
for
(
data_size_t
i
=
0
;
i
<
dataset
->
num_data_
;
++
i
)
{
if
(
random_
.
NextInt
(
0
,
num_machines
)
==
rank
)
{
used_data_indices
.
push_back
(
i
);
used_data_indices
->
push_back
(
i
);
}
}
}
else
{
...
...
@@ -394,13 +394,13 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
++
qid
;
}
if
(
is_query_used
)
{
used_data_indices
.
push_back
(
i
);
used_data_indices
->
push_back
(
i
);
}
}
}
dataset
->
num_data_
=
static_cast
<
data_size_t
>
(
used_data_indices
.
size
());
dataset
->
num_data_
=
static_cast
<
data_size_t
>
(
(
*
used_data_indices
)
.
size
());
}
dataset
->
metadata_
.
PartitionLabel
(
used_data_indices
);
dataset
->
metadata_
.
PartitionLabel
(
*
used_data_indices
);
// read feature data
for
(
int
i
=
0
;
i
<
dataset
->
num_features_
;
++
i
)
{
// read feature size
...
...
@@ -422,8 +422,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
dataset
->
features_
.
emplace_back
(
std
::
unique_ptr
<
Feature
>
(
new
Feature
(
buffer
.
data
(),
num_global_data
,
used_data_indices
)
*
num_global_data
,
*
used_data_indices
)
));
}
dataset
->
features_
.
shrink_to_fit
();
...
...
src/io/metadata.cpp
View file @
1765b2e3
...
...
@@ -12,6 +12,9 @@ Metadata::Metadata() {
num_init_score_
=
0
;
num_data_
=
0
;
num_queries_
=
0
;
weight_load_from_file_
=
false
;
query_load_from_file_
=
false
;
init_score_load_from_file_
=
false
;
}
void
Metadata
::
Init
(
const
char
*
data_filename
)
{
...
...
@@ -40,6 +43,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
for
(
data_size_t
i
=
0
;
i
<
num_weights_
;
++
i
)
{
weights_
[
i
]
=
0.0
f
;
}
weight_load_from_file_
=
false
;
}
if
(
query_idx
>=
0
)
{
if
(
!
query_boundaries_
.
empty
())
{
...
...
@@ -52,6 +56,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
queries_
[
i
]
=
0
;
}
query_load_from_file_
=
false
;
}
}
...
...
@@ -185,87 +190,92 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
Log
::
Fatal
(
"Initial score size doesn't match data size"
);
}
}
else
{
if
(
!
queries_
.
empty
())
{
Log
::
Fatal
(
"Cannot used query_id for parallel training"
);
}
data_size_t
num_used_data
=
static_cast
<
data_size_t
>
(
used_data_indices
.
size
());
// check weights
if
(
weights_
.
size
()
>
0
&&
num_weights_
!=
num_all_data
)
{
weights_
.
clear
();
num_weights_
=
0
;
Log
::
Fatal
(
"Weights size doesn't match data size"
);
}
// check query boundaries
if
(
!
query_boundaries_
.
empty
()
&&
query_boundaries_
[
num_queries_
]
!=
num_all_data
)
{
query_boundaries_
.
clear
();
num_queries_
=
0
;
Log
::
Fatal
(
"Query size doesn't match data size"
);
}
// contain initial score file
if
(
!
init_score_
.
empty
()
&&
(
num_init_score_
%
num_all_data
)
!=
0
)
{
init_score_
.
clear
();
num_init_score_
=
0
;
Log
::
Fatal
(
"Initial score size doesn't match data size"
);
}
// get local weights
if
(
!
weights_
.
empty
())
{
auto
old_weights
=
weights_
;
num_weights_
=
num_data_
;
weights_
=
std
::
vector
<
float
>
(
num_data_
);
if
(
weight_load_from_file_
)
{
if
(
weights_
.
size
()
>
0
&&
num_weights_
!=
num_all_data
)
{
weights_
.
clear
();
num_weights_
=
0
;
Log
::
Fatal
(
"Weights size doesn't match data size"
);
}
// get local weights
if
(
!
weights_
.
empty
())
{
auto
old_weights
=
weights_
;
num_weights_
=
num_data_
;
weights_
=
std
::
vector
<
float
>
(
num_data_
);
#pragma omp parallel for schedule(static)
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
used_data_indices
.
size
());
++
i
)
{
weights_
[
i
]
=
old_weights
[
used_data_indices
[
i
]];
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
used_data_indices
.
size
());
++
i
)
{
weights_
[
i
]
=
old_weights
[
used_data_indices
[
i
]];
}
old_weights
.
clear
();
}
old_weights
.
clear
();
}
// get local query boundaries
if
(
!
query_boundaries_
.
empty
())
{
std
::
vector
<
data_size_t
>
used_query
;
data_size_t
data_idx
=
0
;
for
(
data_size_t
qid
=
0
;
qid
<
num_queries_
&&
data_idx
<
num_used_data
;
++
qid
)
{
data_size_t
start
=
query_boundaries_
[
qid
];
data_size_t
end
=
query_boundaries_
[
qid
+
1
];
data_size_t
len
=
end
-
start
;
if
(
used_data_indices
[
data_idx
]
>
start
)
{
continue
;
}
else
if
(
used_data_indices
[
data_idx
]
==
start
)
{
if
(
num_used_data
>=
data_idx
+
len
&&
used_data_indices
[
data_idx
+
len
-
1
]
==
end
-
1
)
{
used_query
.
push_back
(
qid
);
data_idx
+=
len
;
if
(
query_load_from_file_
)
{
// check query boundaries
if
(
!
query_boundaries_
.
empty
()
&&
query_boundaries_
[
num_queries_
]
!=
num_all_data
)
{
query_boundaries_
.
clear
();
num_queries_
=
0
;
Log
::
Fatal
(
"Query size doesn't match data size"
);
}
// get local query boundaries
if
(
!
query_boundaries_
.
empty
())
{
std
::
vector
<
data_size_t
>
used_query
;
data_size_t
data_idx
=
0
;
for
(
data_size_t
qid
=
0
;
qid
<
num_queries_
&&
data_idx
<
num_used_data
;
++
qid
)
{
data_size_t
start
=
query_boundaries_
[
qid
];
data_size_t
end
=
query_boundaries_
[
qid
+
1
];
data_size_t
len
=
end
-
start
;
if
(
used_data_indices
[
data_idx
]
>
start
)
{
continue
;
}
else
if
(
used_data_indices
[
data_idx
]
==
start
)
{
if
(
num_used_data
>=
data_idx
+
len
&&
used_data_indices
[
data_idx
+
len
-
1
]
==
end
-
1
)
{
used_query
.
push_back
(
qid
);
data_idx
+=
len
;
}
else
{
Log
::
Fatal
(
"Data partition error, data didn't match queries"
);
}
}
else
{
Log
::
Fatal
(
"Data partition error, data didn't match queries"
);
}
}
else
{
Log
::
Fatal
(
"Data partition error, data didn't match queries"
);
}
auto
old_query_boundaries
=
query_boundaries_
;
query_boundaries_
=
std
::
vector
<
data_size_t
>
(
used_query
.
size
()
+
1
);
num_queries_
=
static_cast
<
data_size_t
>
(
used_query
.
size
());
query_boundaries_
[
0
]
=
0
;
for
(
data_size_t
i
=
0
;
i
<
num_queries_
;
++
i
)
{
data_size_t
qid
=
used_query
[
i
];
data_size_t
len
=
old_query_boundaries
[
qid
+
1
]
-
old_query_boundaries
[
qid
];
query_boundaries_
[
i
+
1
]
=
query_boundaries_
[
i
]
+
len
;
}
old_query_boundaries
.
clear
();
}
auto
old_query_boundaries
=
query_boundaries_
;
query_boundaries_
=
std
::
vector
<
data_size_t
>
(
used_query
.
size
()
+
1
);
num_queries_
=
static_cast
<
data_size_t
>
(
used_query
.
size
());
query_boundaries_
[
0
]
=
0
;
for
(
data_size_t
i
=
0
;
i
<
num_queries_
;
++
i
)
{
data_size_t
qid
=
used_query
[
i
];
data_size_t
len
=
old_query_boundaries
[
qid
+
1
]
-
old_query_boundaries
[
qid
];
query_boundaries_
[
i
+
1
]
=
query_boundaries_
[
i
]
+
len
;
}
old_query_boundaries
.
clear
();
}
if
(
init_score_load_from_file_
)
{
// contain initial score file
if
(
!
init_score_
.
empty
()
&&
(
num_init_score_
%
num_all_data
)
!=
0
)
{
init_score_
.
clear
();
num_init_score_
=
0
;
Log
::
Fatal
(
"Initial score size doesn't match data size"
);
}
// get local initial scores
if
(
!
init_score_
.
empty
())
{
auto
old_scores
=
init_score_
;
int
num_class
=
static_cast
<
int
>
(
num_init_score_
/
num_all_data
);
num_init_score_
=
static_cast
<
int64_t
>
(
num_data_
)
*
num_class
;
init_score_
=
std
::
vector
<
double
>
(
num_init_score_
);
// get local initial scores
if
(
!
init_score_
.
empty
())
{
auto
old_scores
=
init_score_
;
int
num_class
=
static_cast
<
int
>
(
num_init_score_
/
num_all_data
);
num_init_score_
=
static_cast
<
int64_t
>
(
num_data_
)
*
num_class
;
init_score_
=
std
::
vector
<
double
>
(
num_init_score_
);
#pragma omp parallel for schedule(static)
for
(
int
k
=
0
;
k
<
num_class
;
++
k
){
for
(
size_t
i
=
0
;
i
<
used_data_indices
.
size
();
++
i
)
{
init_score_
[
k
*
num_data_
+
i
]
=
old_scores
[
k
*
num_all_data
+
used_data_indices
[
i
]];
for
(
int
k
=
0
;
k
<
num_class
;
++
k
)
{
for
(
size_t
i
=
0
;
i
<
used_data_indices
.
size
();
++
i
)
{
init_score_
[
k
*
num_data_
+
i
]
=
old_scores
[
k
*
num_all_data
+
used_data_indices
[
i
]];
}
}
old_scores
.
clear
();
}
old_scores
.
clear
();
}
// re-load query weight
LoadQueryWeights
();
}
...
...
@@ -289,6 +299,7 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
for
(
int64_t
i
=
0
;
i
<
num_init_score_
;
++
i
)
{
init_score_
[
i
]
=
init_score
[
i
];
}
init_score_load_from_file_
=
false
;
}
void
Metadata
::
SetLabel
(
const
float
*
label
,
data_size_t
len
)
{
...
...
@@ -326,6 +337,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) {
weights_
[
i
]
=
weights
[
i
];
}
LoadQueryWeights
();
weight_load_from_file_
=
false
;
}
void
Metadata
::
SetQuery
(
const
data_size_t
*
query
,
data_size_t
len
)
{
...
...
@@ -352,48 +364,7 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
query_boundaries_
[
i
+
1
]
=
query_boundaries_
[
i
]
+
query
[
i
];
}
LoadQueryWeights
();
}
/*!
 * \brief Set query information from one query id per record.
 *
 * Consecutive records that share the same id are grouped into one query, and
 * the group sizes are accumulated into query_boundaries_ (boundary 0 is 0,
 * boundary i + 1 is boundary i plus the size of group i).
 *
 * \param query_id Query id of every record; nullptr resets all query information.
 * \param len Number of entries in query_id; 0 resets all query information.
 *            Otherwise it is checked against num_data_ and Log::Fatal is
 *            called on a mismatch.
 */
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
  std::lock_guard<std::mutex> lock(mutex_);
  // a null or empty input clears every piece of query state
  if (query_id == nullptr || len == 0) {
    query_boundaries_.clear();
    queries_.clear();
    num_queries_ = 0;
    return;
  }
  if (num_data_ != len) {
    Log::Fatal("len of query id is not same with #data");
  }
  // Copy one id per record. The bound has to be num_data_: bounding the copy
  // by num_weights_ left the ids at 0 whenever no weights were set, and could
  // run past the buffers when the weight count exceeded the record count.
  queries_ = std::vector<data_size_t>(num_data_);
  for (data_size_t i = 0; i < num_data_; ++i) {
    queries_[i] = query_id[i];
  }
  // collapse each run of equal ids into the size of that run
  std::vector<data_size_t> tmp_buffer;
  data_size_t last_qid = -1;
  data_size_t cur_cnt = 0;
  for (data_size_t i = 0; i < num_data_; ++i) {
    if (last_qid != queries_[i]) {
      // a new id starts here: flush the size of the previous run, if any
      if (cur_cnt > 0) {
        tmp_buffer.push_back(cur_cnt);
      }
      cur_cnt = 0;
      last_qid = queries_[i];
    }
    ++cur_cnt;
  }
  // flush the final run
  tmp_buffer.push_back(cur_cnt);
  // turn the run sizes into cumulative boundaries
  query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
  num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
  query_boundaries_[0] = 0;
  for (size_t i = 0; i < tmp_buffer.size(); ++i) {
    query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
  }
  // the per-record ids are only needed for the conversion above
  queries_.clear();
  LoadQueryWeights();
  // the boundaries came from the caller, not from a query file
  query_load_from_file_ = false;
}
void
Metadata
::
LoadWeights
()
{
...
...
@@ -415,6 +386,7 @@ void Metadata::LoadWeights() {
Common
::
Atof
(
reader
.
Lines
()[
i
].
c_str
(),
&
tmp_weight
);
weights_
[
i
]
=
static_cast
<
float
>
(
tmp_weight
);
}
weight_load_from_file_
=
true
;
}
void
Metadata
::
LoadInitialScore
()
{
...
...
@@ -457,6 +429,7 @@ void Metadata::LoadInitialScore() {
}
}
}
init_score_load_from_file_
=
true
;
}
void
Metadata
::
LoadQueryBoundaries
()
{
...
...
@@ -478,6 +451,7 @@ void Metadata::LoadQueryBoundaries() {
Common
::
Atoi
(
reader
.
Lines
()[
i
].
c_str
(),
&
tmp_cnt
);
query_boundaries_
[
i
+
1
]
=
query_boundaries_
[
i
]
+
static_cast
<
data_size_t
>
(
tmp_cnt
);
}
query_load_from_file_
=
true
;
}
void
Metadata
::
LoadQueryWeights
()
{
...
...
@@ -516,12 +490,14 @@ void Metadata::LoadFromMemory(const void* memory) {
weights_
=
std
::
vector
<
float
>
(
num_weights_
);
std
::
memcpy
(
weights_
.
data
(),
mem_ptr
,
sizeof
(
float
)
*
num_weights_
);
mem_ptr
+=
sizeof
(
float
)
*
num_weights_
;
weight_load_from_file_
=
true
;
}
if
(
num_queries_
>
0
)
{
if
(
!
query_boundaries_
.
empty
())
{
query_boundaries_
.
clear
();
}
query_boundaries_
=
std
::
vector
<
data_size_t
>
(
num_queries_
+
1
);
std
::
memcpy
(
query_boundaries_
.
data
(),
mem_ptr
,
sizeof
(
data_size_t
)
*
(
num_queries_
+
1
));
mem_ptr
+=
sizeof
(
data_size_t
)
*
(
num_queries_
+
1
);
query_load_from_file_
=
true
;
}
LoadQueryWeights
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment