Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
d2d32668
Commit
d2d32668
authored
Apr 26, 2023
by
yuguo960516yuguo
Browse files
2.3.0-dtk-22.04.2
parent
ad08b8ce
Pipeline
#226
failed with stages
in 0 seconds
Changes
268
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4357 additions
and
0 deletions
+4357
-0
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
+77
-0
paddle/fluid/distributed/ps/table/graph/class_macro.h
paddle/fluid/distributed/ps/table/graph/class_macro.h
+39
-0
paddle/fluid/distributed/ps/table/graph/graph_edge.cc
paddle/fluid/distributed/ps/table/graph/graph_edge.cc
+30
-0
paddle/fluid/distributed/ps/table/graph/graph_edge.h
paddle/fluid/distributed/ps/table/graph/graph_edge.h
+47
-0
paddle/fluid/distributed/ps/table/graph/graph_node.cc
paddle/fluid/distributed/ps/table/graph/graph_node.cc
+121
-0
paddle/fluid/distributed/ps/table/graph/graph_node.h
paddle/fluid/distributed/ps/table/graph/graph_node.h
+138
-0
paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.cc
...luid/distributed/ps/table/graph/graph_weighted_sampler.cc
+164
-0
paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.h
...fluid/distributed/ps/table/graph/graph_weighted_sampler.h
+64
-0
paddle/fluid/distributed/ps/table/memory_dense_table.cc
paddle/fluid/distributed/ps/table/memory_dense_table.cc
+420
-0
paddle/fluid/distributed/ps/table/memory_dense_table.h
paddle/fluid/distributed/ps/table/memory_dense_table.h
+82
-0
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
+246
-0
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
+87
-0
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+743
-0
paddle/fluid/distributed/ps/table/memory_sparse_table.h
paddle/fluid/distributed/ps/table/memory_sparse_table.h
+109
-0
paddle/fluid/distributed/ps/table/sparse_accessor.cc
paddle/fluid/distributed/ps/table/sparse_accessor.cc
+306
-0
paddle/fluid/distributed/ps/table/sparse_accessor.h
paddle/fluid/distributed/ps/table/sparse_accessor.h
+195
-0
paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
+256
-0
paddle/fluid/distributed/ps/table/sparse_sgd_rule.h
paddle/fluid/distributed/ps/table/sparse_sgd_rule.h
+148
-0
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
+998
-0
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
+87
-0
No files found.
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace distributed {

// Describes one "pull sparse" request: a batch of feature signs (keys) with
// their access frequencies. Does NOT own the key/frequency buffers —
// `feasigns_` / `frequencies_` alias caller-owned storage (or the byte
// buffer handed to DeserializeFromBytes), so that storage must outlive this
// object.
struct PullSparseValue {
  // Members carry in-class initializers below, so the default constructor
  // no longer leaves numel_/dim_/pointers indeterminate (the original did,
  // making any read before assignment undefined behavior).
  PullSparseValue() {}

  explicit PullSparseValue(int numel, int dim)
      : numel_(numel),
        dim_(dim),
        is_training_(true),
        feasigns_(nullptr),
        frequencies_(nullptr) {}

  explicit PullSparseValue(std::vector<uint64_t>& feasigns,     // NOLINT
                           std::vector<uint32_t>& frequencies,  // NOLINT
                           int dim) {
    numel_ = feasigns.size();
    dim_ = dim;
    is_training_ = true;
    feasigns_ = feasigns.data();
    frequencies_ = frequencies.data();
  }

  // Points feasigns_/frequencies_ into a serialized buffer laid out as:
  //   |---isTraining--------------|
  //   |---8*{num}B(keysData)------|
  //   |---4*{num}B(Frequencies)---|
  // numel_ must already be set; `bytes` must stay alive while this object
  // is in use (no copy is made).
  void DeserializeFromBytes(void* bytes) {
    auto* begin = reinterpret_cast<char*>(bytes);
    is_training_ = reinterpret_cast<bool*>(begin)[0];
    feasigns_ = reinterpret_cast<uint64_t*>(begin + sizeof(bool));
    frequencies_ = reinterpret_cast<uint32_t*>(begin + sizeof(bool) +
                                               sizeof(uint64_t) * numel_);
  }

  // Collects into *offset_shard the indices of every key that maps (by
  // modulo) onto shard `shard_id` out of `shard_num` shards.
  void Fission(const int shard_id,
               const int shard_num,
               std::vector<int>* offset_shard) const {
    offset_shard->reserve(numel_ / shard_num + 1);
    for (int x = 0; x < numel_; ++x) {
      // static_cast replaces the original C-style function cast.
      if (static_cast<int>(feasigns_[x] % shard_num) == shard_id) {
        offset_shard->push_back(x);
      }
    }
  }

  int numel_ = 0;
  int dim_ = 0;
  bool is_training_ = true;
  uint64_t* feasigns_ = nullptr;     // borrowed, not owned
  uint32_t* frequencies_ = nullptr;  // borrowed, not owned
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/graph/class_macro.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// Emits a single friend declaration for class `a`.
#define DECLARE_GRAPH_FRIEND_CLASS(a) friend class a;
// DECLARE_<n>_FRIEND_CLASS(a1, ..., an) expands to n friend declarations by
// peeling one class name per level and recursing on the remaining arguments.
#define DECLARE_1_FRIEND_CLASS(a, ...) DECLARE_GRAPH_FRIEND_CLASS(a)
#define DECLARE_2_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_1_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_3_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_2_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_4_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_3_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_5_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_4_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_6_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_5_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_7_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_6_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_8_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_7_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_9_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_8_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_10_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_9_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_11_FRIEND_CLASS(a, ...) \
  DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
// Entry point: REGISTER_GRAPH_FRIEND_CLASS(3, A, B, C) pastes the count into
// a DECLARE_3_FRIEND_CLASS invocation, declaring A, B and C as friends.
// The count `n` must match the number of class names (max 11).
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
  DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
paddle/fluid/distributed/ps/table/graph/graph_edge.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/graph/graph_edge.h"
#include <cstring>
namespace paddle {
namespace distributed {

// Unweighted blob: record the neighbor id; the weight argument is
// deliberately ignored (get_weight() always reports 1 for this class).
void GraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
  id_arr.push_back(id);
}

// Weighted blob: keep id and weight in parallel arrays (same index).
void WeightedGraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
  id_arr.push_back(id);
  weight_arr.push_back(weight);
}

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/graph/graph_edge.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {

// Adjacency list for one graph node: stores neighbor ids only. The weighted
// subclass below adds a parallel per-edge weight array.
class GraphEdgeBlob {
 public:
  GraphEdgeBlob() {}
  virtual ~GraphEdgeBlob() {}
  // Number of stored edges.
  size_t size() { return id_arr.size(); }
  // Appends a neighbor; defined in graph_edge.cc. This base class records
  // only the id and drops the weight.
  virtual void add_edge(int64_t id, float weight);
  // Neighbor id at position idx (no bounds check).
  int64_t get_id(int idx) { return id_arr[idx]; }
  // Unweighted blob: every edge reports weight 1.
  virtual float get_weight(int idx) { return 1; }
  // Direct mutable access to the id array (used for bulk import/export).
  std::vector<int64_t>& export_id_array() { return id_arr; }

 protected:
  std::vector<int64_t> id_arr;
};
// Weighted adjacency list: id_arr (inherited) and weight_arr are parallel
// arrays indexed by edge position.
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
 public:
  WeightedGraphEdgeBlob() {}
  virtual ~WeightedGraphEdgeBlob() {}
  // Appends a neighbor together with its weight; defined in graph_edge.cc.
  virtual void add_edge(int64_t id, float weight);
  // Weight of edge idx (no bounds check).
  virtual float get_weight(int idx) { return weight_arr[idx]; }

 protected:
  std::vector<float> weight_arr;
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/graph/graph_node.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include <cstring>
namespace paddle {
namespace distributed {

// Releases the owned sampler and edge blob. `delete` on a null pointer is a
// no-op, so no explicit null checks are needed before deleting.
GraphNode::~GraphNode() {
  delete sampler;
  sampler = nullptr;
  delete edges;
  edges = nullptr;
}
// Serialized field widths shared by all node types (see to_buffer /
// recover_from_buffer below).
int Node::weight_size = sizeof(float);
int Node::id_size = sizeof(uint64_t);
int Node::int_size = sizeof(int);

// Serialized size of a plain Node: the id plus a feature-count field.
// The base class carries no features, so `need_feature` is ignored.
int Node::get_size(bool need_feature) { return id_size + int_size; }
// Writes [id | feat_num=0] into `buffer`. The base Node has no features, so
// `need_feature` is ignored and the feature count is always zero. The caller
// must size `buffer` via get_size().
void Node::to_buffer(char* buffer, bool need_feature) {
  memcpy(buffer, &id, id_size);
  buffer += id_size;
  int feat_num = 0;
  memcpy(buffer, &feat_num, sizeof(int));
}

// Reads back only the id; the feature-count field written by to_buffer is
// deliberately ignored here.
void Node::recover_from_buffer(char* buffer) { memcpy(&id, buffer, id_size); }
// Serialized size: id + feature count, and — when need_feature is set —
// one int length prefix plus the raw bytes of every feature string.
int FeatureNode::get_size(bool need_feature) {
  int total = id_size + int_size;  // id, feat_num
  if (need_feature) {
    // Each feature contributes its length prefix and its payload.
    for (const std::string& item : feature) {
      total += int_size + item.size();
    }
  }
  return total;
}
void
GraphNode
::
build_edges
(
bool
is_weighted
)
{
if
(
edges
==
nullptr
)
{
if
(
is_weighted
==
true
)
{
edges
=
new
WeightedGraphEdgeBlob
();
}
else
{
edges
=
new
GraphEdgeBlob
();
}
}
}
// Lazily builds the neighbor sampler ("random" = uniform, "weighted" =
// weight-proportional) over the current edge blob. A second call is a no-op.
void GraphNode::build_sampler(std::string sample_type) {
  if (sampler != nullptr) {
    return;
  }
  if (sample_type == "random") {
    sampler = new RandomSampler();
  } else if (sample_type == "weighted") {
    sampler = new WeightedSampler();
  } else {
    // BUGFIX: an unrecognized sample_type previously fell through and
    // dereferenced the still-null sampler below (undefined behavior).
    // Leave the node without a sampler instead.
    return;
  }
  sampler->build(edges);
}
// Serializes as [id | feat_num | (len_i | bytes_i)*]. When need_feature is
// false only [id | feat_num=0] is written. The caller must size `buffer`
// using get_size(need_feature); feature payloads are copied raw (they may
// contain arbitrary bytes).
void FeatureNode::to_buffer(char* buffer, bool need_feature) {
  memcpy(buffer, &id, id_size);
  buffer += id_size;
  int feat_num = 0;
  int feat_len;
  if (need_feature) {
    feat_num += feature.size();
    memcpy(buffer, &feat_num, sizeof(int));
    buffer += sizeof(int);
    // Each feature: 4-byte length prefix followed by the raw bytes.
    for (int i = 0; i < feat_num; ++i) {
      feat_len = feature[i].size();
      memcpy(buffer, &feat_len, sizeof(int));
      buffer += sizeof(int);
      memcpy(buffer, feature[i].c_str(), feature[i].size());
      buffer += feature[i].size();
    }
  } else {
    memcpy(buffer, &feat_num, sizeof(int));
  }
}
// Inverse of to_buffer: reads [id | feat_num | (len_i | bytes_i)*] and
// rebuilds the feature vector. Any existing features are discarded.
void FeatureNode::recover_from_buffer(char* buffer) {
  int feat_num, feat_len;
  memcpy(&id, buffer, id_size);
  buffer += id_size;
  memcpy(&feat_num, buffer, sizeof(int));
  buffer += sizeof(int);
  feature.clear();
  for (int i = 0; i < feat_num; ++i) {
    memcpy(&feat_len, buffer, sizeof(int));
    buffer += sizeof(int);
    // BUGFIX: construct the string directly from (pointer, length).
    // The original copied through `char str[feat_len + 1]` — a variable-
    // length array, which is not standard C++ — and then used the
    // NUL-terminated std::string constructor, which would truncate any
    // feature containing an embedded '\0' byte (to_buffer copies raw
    // bytes, so such features are possible).
    feature.emplace_back(buffer, feat_len);
    buffer += feat_len;
  }
}
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/graph/graph_node.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>
#include "paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.h"
namespace paddle {
namespace distributed {

// Base class for graph nodes. Holds only the node id; edge storage,
// sampling and feature handling are no-op defaults overridden by GraphNode
// and FeatureNode below.
class Node {
 public:
  Node() {}
  Node(uint64_t id) : id(id) {}
  virtual ~Node() {}
  // Byte widths used by (de)serialization; defined in graph_node.cc.
  static int id_size, int_size, weight_size;
  uint64_t get_id() { return id; }
  void set_id(uint64_t id) { this->id = id; }

  // Edge / sampler management — no-ops in the base class.
  virtual void build_edges(bool is_weighted) {}
  virtual void build_sampler(std::string sample_type) {}
  virtual void add_edge(uint64_t id, float weight) {}
  // Samples k neighbor slots; the base class has no neighbors.
  virtual std::vector<int> sample_k(
      int k, const std::shared_ptr<std::mt19937_64> rng) {
    return std::vector<int>();
  }
  virtual uint64_t get_neighbor_id(int idx) { return 0; }
  virtual float get_neighbor_weight(int idx) { return 1.; }

  // Serialization; implemented in graph_node.cc.
  virtual int get_size(bool need_feature);
  virtual void to_buffer(char* buffer, bool need_feature);
  virtual void recover_from_buffer(char* buffer);

  // Feature accessors — no-op defaults, overridden by FeatureNode.
  virtual std::string get_feature(int idx) { return std::string(""); }
  virtual void set_feature(int idx, std::string str) {}
  virtual void set_feature_size(int size) {}
  virtual int get_feature_size() { return 0; }
  virtual size_t get_neighbor_size() { return 0; }

 protected:
  uint64_t id;
  // NOTE(review): declared but never referenced in this header; presumably
  // mirrored by build_edges(is_weighted) in subclasses — confirm usage.
  bool is_weighted;
};
// Node with real adjacency storage plus a (random or weighted) neighbor
// sampler. Owns both `sampler` and `edges`; the destructor (graph_node.cc)
// frees them.
class GraphNode : public Node {
 public:
  GraphNode() : Node(), sampler(nullptr), edges(nullptr) {}
  GraphNode(uint64_t id) : Node(id), sampler(nullptr), edges(nullptr) {}
  virtual ~GraphNode();
  virtual void build_edges(bool is_weighted);
  virtual void build_sampler(std::string sample_type);
  // Requires build_edges() to have been called first (edges must be non-null).
  virtual void add_edge(uint64_t id, float weight) {
    edges->add_edge(id, weight);
  }
  // Requires build_sampler() to have been called first.
  virtual std::vector<int> sample_k(
      int k, const std::shared_ptr<std::mt19937_64> rng) {
    return sampler->sample_k(k, rng);
  }
  virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
  virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
  virtual size_t get_neighbor_size() { return edges->size(); }

 protected:
  Sampler* sampler;     // owned
  GraphEdgeBlob* edges;  // owned
};
// Node that carries a list of string features, each (de)serializable via
// to_buffer / recover_from_buffer (graph_node.cc). The static parse_*
// helpers convert between textual feature values and packed binary blobs.
class FeatureNode : public Node {
 public:
  FeatureNode() : Node() {}
  FeatureNode(uint64_t id) : Node(id) {}
  virtual ~FeatureNode() {}
  virtual int get_size(bool need_feature);
  virtual void to_buffer(char* buffer, bool need_feature);
  virtual void recover_from_buffer(char* buffer);
  // Returns feature idx, or "" when idx is out of range.
  virtual std::string get_feature(int idx) {
    if (idx < (int)this->feature.size()) {
      return this->feature[idx];
    } else {
      return std::string("");
    }
  }
  // Stores str at slot idx, growing the feature vector as needed.
  virtual void set_feature(int idx, std::string str) {
    if (idx >= (int)this->feature.size()) {
      this->feature.resize(idx + 1);
    }
    this->feature[idx] = str;
  }
  virtual void set_feature_size(int size) { this->feature.resize(size); }
  virtual int get_feature_size() { return this->feature.size(); }

  // Parses each string in feat_str as a T (via operator>>) and packs the
  // raw T values back-to-back into a byte string of length
  // sizeof(T) * feat_str.size().
  template <typename T>
  static std::string parse_value_to_bytes(std::vector<std::string> feat_str) {
    T v;
    size_t Tsize = sizeof(T) * feat_str.size();
    // BUGFIX: the original used `char buffer[Tsize]`, a variable-length
    // array — a compiler extension, not standard C++. Build the result
    // string directly instead.
    std::string buffer(Tsize, '\0');
    for (size_t i = 0; i < feat_str.size(); i++) {
      std::stringstream ss(feat_str[i]);
      ss >> v;
      std::memcpy(&buffer[sizeof(T) * i], reinterpret_cast<char*>(&v),
                  sizeof(T));
    }
    return buffer;
  }

  // Inverse of parse_value_to_bytes: splits feat_str into sizeof(T)-byte
  // chunks and reinterprets each as a T.
  template <typename T>
  static std::vector<T> parse_bytes_to_array(std::string feat_str) {
    T v;
    std::vector<T> out;
    size_t start = 0;
    const char* buffer = feat_str.data();
    while (start < feat_str.size()) {
      std::memcpy(reinterpret_cast<char*>(&v), buffer + start, sizeof(T));
      start += sizeof(T);
      out.push_back(v);
    }
    return out;
  }

 protected:
  std::vector<std::string> feature;
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.h"
#include <iostream>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/generator.h"
namespace
paddle
{
namespace
distributed
{
void
RandomSampler
::
build
(
GraphEdgeBlob
*
edges
)
{
this
->
edges
=
edges
;
}
// Draws k distinct edge indices uniformly without replacement.
// Uses a "virtual Fisher-Yates" shuffle: instead of materializing [0, n),
// replace_map records which value currently occupies a drawn slot, and each
// consumed slot is logically refilled with the value from position n-1
// before n shrinks — so every draw stays O(1) expected time.
std::vector<int> RandomSampler::sample_k(
    int k, const std::shared_ptr<std::mt19937_64> rng) {
  int n = edges->size();
  if (k >= n) {
    // Asked for at least everything: return all indices in order.
    k = n;
    std::vector<int> sample_result;
    for (int i = 0; i < k; i++) {
      sample_result.push_back(i);
    }
    return sample_result;
  }
  std::vector<int> sample_result;
  std::unordered_map<int, int> replace_map;
  while (k--) {
    std::uniform_int_distribution<int> distrib(0, n - 1);
    int rand_int = distrib(*rng);
    auto iter = replace_map.find(rand_int);
    if (iter == replace_map.end()) {
      sample_result.push_back(rand_int);
    } else {
      sample_result.push_back(iter->second);
    }
    // Refill the consumed slot with whatever value lives at slot n-1.
    iter = replace_map.find(n - 1);
    if (iter == replace_map.end()) {
      replace_map[rand_int] = n - 1;
    } else {
      replace_map[rand_int] = iter->second;
    }
    --n;
  }
  return sample_result;
}
// Tree nodes start empty; build()/build_one() populate them.
WeightedSampler::WeightedSampler() {
  left = right = nullptr;
  edges = nullptr;
}

// Recursively frees both subtrees (`delete` of null is a no-op, so no
// explicit null checks are required).
WeightedSampler::~WeightedSampler() {
  delete left;
  left = nullptr;
  delete right;
  right = nullptr;
}
// Rebuilds the sampling tree over the full [0, size) edge range, discarding
// any previously built subtrees first. The blob must actually be a
// WeightedGraphEdgeBlob.
void WeightedSampler::build(GraphEdgeBlob* edges) {
  delete left;
  left = nullptr;
  delete right;
  right = nullptr;
  return build_one(static_cast<WeightedGraphEdgeBlob*>(edges), 0,
                   edges->size());
}
// Recursively builds a segment tree over edges [start, end): a leaf stores
// one edge's index and weight; an inner node caches the weight and leaf
// count of its two halves so sampling can descend in O(log n).
// Precondition: start < end.
void WeightedSampler::build_one(WeightedGraphEdgeBlob* edges,
                                int start,
                                int end) {
  count = 0;
  this->edges = edges;
  if (start + 1 == end) {
    // Leaf: exactly one edge.
    left = right = nullptr;
    idx = start;
    count = 1;
    weight = edges->get_weight(idx);
  } else {
    // Inner node: split the range in half and aggregate child sums.
    left = new WeightedSampler();
    right = new WeightedSampler();
    left->build_one(edges, start, start + (end - start) / 2);
    right->build_one(edges, start + (end - start) / 2, end);
    weight = left->weight + right->weight;
    count = left->count + right->count;
  }
}
// Draws k distinct edge indices with probability proportional to weight,
// without replacement. Each draw is one descent of the segment tree; the
// subtract maps track the weight/count already consumed beneath every
// visited subtree so removed leaves cannot be drawn again.
std::vector<int> WeightedSampler::sample_k(
    int k, const std::shared_ptr<std::mt19937_64> rng) {
  if (k >= count) {
    // Asked for at least everything: return all leaf positions in order.
    k = count;
    std::vector<int> sample_result;
    for (int i = 0; i < k; i++) {
      sample_result.push_back(i);
    }
    return sample_result;
  }
  std::vector<int> sample_result;
  float subtract;
  std::unordered_map<WeightedSampler*, float> subtract_weight_map;
  std::unordered_map<WeightedSampler*, int> subtract_count_map;
  std::uniform_real_distribution<float> distrib(0, 1.0);
  while (k--) {
    float query_weight = distrib(*rng);
    // Scale the uniform draw into the still-available weight mass.
    query_weight *= weight - subtract_weight_map[this];
    sample_result.push_back(sample(
        query_weight, subtract_weight_map, subtract_count_map, subtract));
  }
  return sample_result;
}
// One weighted draw: descends from this subtree to a leaf whose cumulative
// available weight covers query_weight, marks that leaf's weight/count as
// consumed in the subtract maps (so later draws skip it), and returns the
// leaf's edge index. `subtract` reports the removed leaf's weight upward so
// every ancestor can update its own entry.
int WeightedSampler::sample(
    float query_weight,
    std::unordered_map<WeightedSampler*, float>& subtract_weight_map,
    std::unordered_map<WeightedSampler*, int>& subtract_count_map,
    float& subtract) {
  if (left == nullptr) {
    // Leaf: consume it entirely.
    subtract_weight_map[this] = weight;
    subtract = weight;
    subtract_count_map[this] = 1;
    return idx;
  }
  int left_count = left->count - subtract_count_map[left];
  int right_count = right->count - subtract_count_map[right];
  float left_subtract = subtract_weight_map[left];
  int return_idx;
  // Go left when the right side is exhausted, or when the left side still
  // has leaves and enough remaining weight to cover the query.
  // BUGFIX(clarity): parenthesized the `&&` operand — the original relied
  // on `&&` binding tighter than `||` (same behavior, but it triggers
  // -Wparentheses and is easy to misread).
  if (right_count == 0 ||
      (left_count > 0 && left->weight - left_subtract >= query_weight)) {
    return_idx = left->sample(
        query_weight, subtract_weight_map, subtract_count_map, subtract);
  } else {
    // Descend right with the left side's remaining weight subtracted out.
    return_idx = right->sample(query_weight - (left->weight - left_subtract),
                               subtract_weight_map,
                               subtract_count_map,
                               subtract);
  }
  subtract_weight_map[this] += subtract;
  subtract_count_map[this]++;
  return return_idx;
}
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ctime>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/table/graph/graph_edge.h"
namespace paddle {
namespace distributed {

// Abstract neighbor-sampling strategy over a node's edge blob.
class Sampler {
 public:
  virtual ~Sampler() {}
  // Prepares internal state for the given edge blob (borrowed, not owned).
  virtual void build(GraphEdgeBlob* edges) = 0;
  // Returns k distinct edge indices drawn per the strategy's distribution.
  virtual std::vector<int> sample_k(
      int k, const std::shared_ptr<std::mt19937_64> rng) = 0;
};
// Uniform sampling without replacement; see graph_weighted_sampler.cc.
class RandomSampler : public Sampler {
 public:
  virtual ~RandomSampler() {}
  virtual void build(GraphEdgeBlob* edges);
  virtual std::vector<int> sample_k(
      int k, const std::shared_ptr<std::mt19937_64> rng);
  GraphEdgeBlob* edges;  // borrowed, not owned
};
// Weight-proportional sampling without replacement, implemented as a
// segment tree: leaves are single edges, inner nodes cache the weight and
// leaf count of their subtrees. See graph_weighted_sampler.cc.
class WeightedSampler : public Sampler {
 public:
  WeightedSampler();
  virtual ~WeightedSampler();
  WeightedSampler *left, *right;  // owned subtrees (null at leaves)
  float weight;                   // total weight under this subtree
  int count;                      // number of leaves under this subtree
  int idx;                        // edge index (meaningful at leaves only)
  GraphEdgeBlob* edges;           // borrowed, not owned
  virtual void build(GraphEdgeBlob* edges);
  // Builds the subtree covering edge range [start, end).
  virtual void build_one(WeightedGraphEdgeBlob* edges, int start, int end);
  virtual std::vector<int> sample_k(
      int k, const std::shared_ptr<std::mt19937_64> rng);

 private:
  // One weighted draw; the subtract maps accumulate already-sampled
  // weight/count per subtree so successive draws are without replacement.
  int sample(float query_weight,
             std::unordered_map<WeightedSampler*, float>& subtract_weight_map,
             std::unordered_map<WeightedSampler*, int>& subtract_count_map,
             float& subtract);
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/memory_dense_table.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_dense_table.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {

// Max retries when saving a dense table (pslib-compatible knob).
int FLAGS_pslib_table_save_max_retry_dense = 3;

// Parses an initializer attribute string (fields joined by '&', e.g.
// "gaussian_random&...") and registers the matching initializer for the
// variable `name`. Throws InvalidArgument for unknown initializer kinds.
// NOTE(review): the raw `new` results are stored in initializers_;
// ownership appears to live with that map — confirm teardown frees them.
void MemoryDenseTable::CreateInitializer(const std::string& attr,
                                         const std::string& name) {
  auto slices = string::split_string<std::string>(attr, "&");
  if (slices[0] == "gaussian_random") {
    initializers_[name] = new GaussianInitializer(slices);
  } else if (slices[0] == "fill_constant") {
    initializers_[name] = new FillConstantInitializer(slices);
  } else if (slices[0] == "uniform_random") {
    initializers_[name] = new UniformInitializer(slices);
  } else if (slices[0] == "truncated_gaussian_random") {
    initializers_[name] = new TruncatedGaussianInitializer(slices);
  } else {
    PADDLE_THROW(
        platform::errors::InvalidArgument("%s can not be supported", name));
  }
}
// One-time table setup: creates one single-threaded pool per shard task,
// reads the sync flag from config, allocates the shared learning rate, then
// initializes parameter values and the optimizer. Always returns 0.
int32_t MemoryDenseTable::Initialize() {
  _shards_task_pool.resize(task_pool_size_);
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
    // One thread per pool: tasks queued to the same shard run serially.
    _shards_task_pool[i].reset(new ::ThreadPool(1));
  }
  sync = _config.common().sync();
  VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
  // Shared learning-rate cell handed to the optimizer; may later be
  // repointed by SetGlobalLR().
  _global_lr = new float(1.0);
  InitializeValue();
  InitializeOptimizer();
  return 0;
}
// Allocates and fills every parameter column from its configured
// initializer. Also records which column is "Param" (param_idx_/param_dim_),
// sums total_dim_, and classifies columns: those whose dim matches
// param_dim_ are parameter columns (param_col_ids_), the rest count toward
// fixed_len_params_dim_. Always returns 0.
int32_t MemoryDenseTable::InitializeValue() {
  auto common = _config.common();
  int size = static_cast<int>(common.params().size());
  values_.resize(size);
  total_dim_ = 0;
  for (int x = 0; x < size; ++x) {
    auto& varname = common.params()[x];
    auto& dim = common.dims()[x];
    if (varname == "Param") {
      param_dim_ = dim;
      param_idx_ = x;
    }
    auto& initializer = common.initializers()[x];
    total_dim_ += dim;
    CreateInitializer(initializer, varname);
    values_[x].resize(dim);
    names_index_[varname] = x;
    // Seed every element from the column's initializer.
    for (size_t y = 0; y < dim; ++y) {
      values_[x][y] = initializers_[varname]->GetValue();
    }
  }
  fixed_len_params_dim_ = 0;
  for (int x = 0; x < size; ++x) {
    auto& dim = common.dims()[x];
    if (static_cast<int>(dim) != param_dim_) {
      fixed_len_params_dim_ += dim;
    } else {
      param_col_ids_.push_back(x);
    }
  }
  if (_config.common().name() == "adam_d2sum") {
    // adam_d2sum expects a placeholder column at position 1; -1 marks a
    // column with no backing storage (skipped when loading/saving).
    param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
  }
  VLOG(1) << "MemoryDenseTable::InitializeValue total dim: " << total_dim_
          << " fixed_len_params_dim: " << fixed_len_params_dim_;
  pull_reservoir_ = ReservoirValue<float>(param_dim_);
  return 0;
}
// Instantiates the dense optimizer named in the config (sgd / adam /
// adam_d2sum / sum / summary) over values_. Unknown names only log a
// failure and leave optimizer_ unset. Always returns 0.
int32_t MemoryDenseTable::InitializeOptimizer() {
  auto common = _config.common();
  auto name = common.name();
  auto attrs = common.attributes();
  if (name == "sgd") {
    optimizer_ = std::make_shared<DSGD>(common, &values_);
    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "adam") {
    optimizer_ = std::make_shared<DAdam>(common, &values_);
    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "adam_d2sum") {
    optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
    // optimizer_->SetGlobalLR(_global_lr); //no use
  } else if (name == "sum") {
    optimizer_ = std::make_shared<DSUM>(common, &values_);
  } else if (name == "summary") {
    optimizer_ = std::make_shared<DSummary>(common, &values_);
  } else {
    VLOG(0) << "init optimizer failed";
  }
  VLOG(3) << "init optimizer " << name << " done";
  return 0;
}
// Repoints the shared learning-rate cell and propagates it to the
// optimizer. NOTE(review): the float allocated in Initialize() is not freed
// here, and `lr` must outlive the table — confirm ownership with callers.
int32_t MemoryDenseTable::SetGlobalLR(float* lr) {
  _global_lr = lr;
  optimizer_->SetGlobalLR(_global_lr);
  return 0;
}
// Table-level entry point for dense pulls; forwards straight to PullDense.
int32_t MemoryDenseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Dense);
  return PullDense(context.pull_context.values, context.num);
}
// Table-level entry point for dense pushes: gradient updates go to
// PushDense, raw parameter assignments to PushDenseParam. A null value
// pointer is a no-op.
int32_t MemoryDenseTable::Push(TableContext& context) {
  CHECK(context.value_type == Dense);
  auto* values = context.push_context.values;
  if (values == nullptr) {
    return 0;
  }
  if (context.push_context.is_param) {
    return PushDenseParam(values, context.num);
  }
  return PushDense(values, context.num);
}
// Copies the current parameter column into the caller's buffer, which must
// hold at least param_dim_ floats (`num` is not checked here).
int32_t MemoryDenseTable::PullDense(float* pull_values, size_t num) {
  const auto& param = values_[param_idx_];
  std::copy(param.begin(), param.end(), pull_values);
  return 0;
}
// Overwrites the parameter column with caller-provided values (no optimizer
// involved). `values` must supply at least param_dim_ floats.
int32_t MemoryDenseTable::PushDenseParam(const float* values, size_t num) {
  PADDLE_ENFORCE_GE(
      num,
      param_dim_,
      paddle::platform::errors::InvalidArgument(
          // BUGFIX: error message typo "desne" -> "dense".
          "update dense param numel expected %d, but got %d",
          param_dim_,
          num));
  std::copy_n(values, param_dim_, values_[param_idx_].begin());
  return 0;
}
// Sync-mode flush: averages the gradients accumulated in the reservoir,
// applies them through _PushDense, then clears the reservoir for the next
// round. Always returns 0.
int32_t MemoryDenseTable::Pour() {
  pull_reservoir_.avg();
  _PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
  pull_reservoir_.reset();
  return 0;
}
// Applies a dense gradient. In sync mode the gradient is only accumulated
// into the reservoir (averaged and applied later by Pour()); otherwise it
// is applied immediately via _PushDense. Queueing on pool 0 serializes
// concurrent accumulations; the wait() makes the call synchronous.
int32_t MemoryDenseTable::PushDense(const float* values, size_t num) {
  if (sync) {
    std::future<int> task =
        _shards_task_pool[0]->enqueue([this, &values]() -> int {
          pull_reservoir_.add(values, param_dim_);
          return 0;
        });
    task.wait();
  } else {
    _PushDense(values, num);
  }
  return 0;
}
// Applies a dense gradient through the optimizer, sharded across the task
// pools: the [0, param_dim_) range is split into task_pool_size_ buckets
// and each shard updates its bucket in parallel; all tasks are awaited
// before returning. `values` must supply at least param_dim_ floats.
int32_t MemoryDenseTable::_PushDense(const float* values, size_t num) {
  PADDLE_ENFORCE_GE(
      num,
      param_dim_,
      paddle::platform::errors::InvalidArgument(
          // BUGFIX: error message typo "desne" -> "dense".
          "update dense numel expected %d, but got %d", param_dim_, num));
  std::vector<int> buckets = bucket(param_dim_, task_pool_size_);
  std::vector<std::future<int>> tasks(task_pool_size_);
  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &buckets, &values]() -> int {
          auto begin = buckets[shard_id];
          auto end = buckets[shard_id + 1];
          optimizer_->Update(values, param_dim_, begin, end);
          return 0;
        });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  VLOG(2) << "debug MemoryDenseTable::_push_dense done";
  return 0;
}
// Load this node's slice of the dense parameters from `path`.
// The table is stored as a sorted list of text files; the global dimension
// range is split evenly across files, and this shard owns dimensions
// [dim_num_per_shard * _shard_idx, ... + param_dim_). Each text line holds
// one dimension's values, one float per saved column (param_col_ids_).
// `param` is the numeric load mode passed to the accessor's converter.
int32_t MemoryDenseTable::Load(const std::string& path,
                               const std::string& param) {
  if (param_dim_ <= 0) {
    return 0;
  }
  std::string table_path = TableDir(path);
  auto file_list = _afs_client.list(table_path);
  std::sort(file_list.begin(), file_list.end());
  for (auto ff : file_list) {
    VLOG(1) << "load dense table file list: " << ff;
  }
  // Each file carries fea_dim/file_count (+1 rounding) dimensions.
  size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
  // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
  size_t dim_num_per_shard =
      _value_accesor->GetAccessorInfo().fea_dim / _shard_num + 1;
  // First global dimension owned by this shard, and the file range covering
  // [start_dim_idx, start_dim_idx + param_dim_).
  size_t start_dim_idx = dim_num_per_shard * _shard_idx;
  size_t start_file_idx = start_dim_idx / dim_num_per_file;
  size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
  end_file_idx =
      end_file_idx < file_list.size() ? end_file_idx : file_list.size() - 1;
  VLOG(2) << "load dense table start_file_idx: " << start_file_idx
          << " end_file_idx: " << end_file_idx;
  int load_param = atoi(param.c_str());
  FsChannelConfig channel_config;
  channel_config.converter = _value_accesor->Converter(load_param).converter;
  channel_config.deconverter =
      _value_accesor->Converter(load_param).deconverter;
  bool is_read_failed = false;
  int err_no = 0;
  int retry_num = 0;
  // Retry the whole load on any exception until a clean pass completes.
  do {
    is_read_failed = false;
    try {
      int dim_idx = 0;  // next local dimension to fill, in [0, param_dim_)
      // NOTE(review): fixed 5-float scratch buffer — assumes
      // param_col_ids_.size() <= 5; a wider row would overflow. TODO confirm.
      float data_buffer[5];
      float* data_buff_ptr = data_buffer;
      std::string line_data;
      auto common = _config.common();  // NOTE(review): appears unused here
      for (size_t i = start_file_idx; i < end_file_idx + 1; ++i) {
        channel_config.path = file_list[i];
        err_no = 0;
        auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
        // Offset of this shard's first dimension inside file i; updated via
        // start_dim_idx as files are consumed.
        size_t file_start_idx = start_dim_idx - i * dim_num_per_file;
        // not all file contains param and the length of last file containing
        // param may not equal to others
        size_t file_dim_idx = 0;
        for (; file_dim_idx < dim_num_per_file; ++file_dim_idx) {
          if (read_channel->read_line(line_data) != 0) {
            break;  // file exhausted
          }
          if (dim_idx >= param_dim_) {
            break;  // this shard's slice fully loaded
          }
          if (file_dim_idx < file_start_idx) {
            continue;  // line belongs to an earlier shard
          }
          size_t str_len =
              paddle::string::str_to_float(line_data.data(), data_buff_ptr);
          CHECK(str_len == param_col_ids_.size())
              << "expect " << param_col_ids_.size() << " float, but got "
              << str_len;
          // Scatter the row's columns into their parameter vectors; a
          // negative column id means "column not loaded".
          for (size_t col_idx = 0; col_idx < str_len; ++col_idx) {
            if (param_col_ids_[col_idx] < 0) {
              continue;
            }
            values_[param_col_ids_[col_idx]][dim_idx] = data_buffer[col_idx];
            VLOG(2) << "MemoryDenseTable::load param x: "
                    << param_col_ids_[col_idx] << " y: " << dim_idx
                    << " value: " << values_[param_col_ids_[col_idx]][dim_idx]
                    << " line " << file_dim_idx;
          }
          ++dim_idx;
        }
        read_channel->close();
        VLOG(1) << "DownpourDenseTable load done " << channel_config.path
                << " file_start_idx: " << file_start_idx
                << " dim_idx: " << dim_idx;
        if (err_no == -1) {
          // Channel-level read error: retry the same file (--i) up to the
          // configured limit.
          if (retry_num > FLAGS_pslib_table_save_max_retry_dense) {
            LOG(ERROR) << "DownpourDenseTable load failed reach max limit!";
            exit(-1);
          }
          ++retry_num;
          --i;
          LOG(ERROR)
              << "DownpourDenseTable load failed after read , retry it! path:"
              << channel_config.path << ", retry_num=" << retry_num;
          continue;
        }
        retry_num = 0;
        // Advance to the first dimension expected from the next file.
        start_dim_idx += file_dim_idx - file_start_idx;
        LOG(INFO) << "DownpourDenseTable load success, path:"
                  << channel_config.path;
      }
    } catch (...) {
      is_read_failed = true;
      LOG(ERROR) << "DownpourDenseTable load failed, retry it! path:"
                 << channel_config.path;
    }
  } while (is_read_failed);
  return 0;
}
// Save this shard's dense parameters under `path` as one text file
// (part-%03d[.gz]), one line per parameter dimension. `param` selects the
// accessor converter (save mode). Returns feasign_size (incremented once per
// successful write pass).
int32_t MemoryDenseTable::Save(const std::string& path,
                               const std::string& param) {
  int save_param = atoi(param.c_str());
  uint32_t feasign_size;
  VLOG(0) << "MemoryDenseTable::save path " << path;
  FsChannelConfig channel_config;
  if (_config.compress_in_save()) {
    channel_config.path = paddle::string::format_string(
        "%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx);
  } else {
    channel_config.path = paddle::string::format_string(
        "%s/part-%03d", TableDir(path).c_str(), _shard_idx);
  }
  _afs_client.remove(channel_config.path);
  channel_config.converter = _value_accesor->Converter(save_param).converter;
  channel_config.deconverter =
      _value_accesor->Converter(save_param).deconverter;
  bool is_write_failed = false;
  // One row of strings per dimension; columns are the per-dim parameters.
  std::vector<std::vector<std::string>> result_buffer_param(
      param_dim_, std::vector<std::string>());
  std::vector<std::string> result_buffer_fixed_len;
  result_buffer_fixed_len.reserve(fixed_len_params_dim_);
  auto common = _config.common();
  int size = static_cast<int>(common.params().size());
  if (_config.common().name() == "summary") {
    // Summary tables persist a single vector (values_[param_idx_]).
    for (int x = 0; x < param_dim_; ++x) {
      result_buffer_param[x].emplace_back(
          std::to_string(values_[param_idx_][x]));
    }
  } else {
    std::ostringstream os;
    for (int x = 0; x < size; ++x) {
      int dim = common.dims()[x];
      VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim;
      for (int y = 0; y < dim; ++y) {
        os.clear();
        os.str("");
        os << values_[x][y];
        // Full-width params go one-per-dimension into the row buffer;
        // anything shorter is a fixed-length param saved separately.
        if (dim == param_dim_) {
          result_buffer_param[y].emplace_back(std::move(os.str()));
        } else {
          result_buffer_fixed_len.emplace_back(std::move(os.str()));
        }
      }
    }
  }
  int retry_num = 0;
  int err_no = 0;
  // Rewrite the whole file until a full pass succeeds (bounded retries).
  do {
    err_no = 0;
    is_write_failed = false;
    feasign_size = 0;
    // 40M
    auto write_channel =
        _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
    for (auto& t : result_buffer_param) {
      if (_config.common().name() == "adam_d2sum") {
        // NOTE(review): on a retry pass this inserts "0" into the same row
        // again, duplicating the column — looks like a latent bug; verify
        // against upstream behavior.
        t.insert(t.begin() + 1, "0");  // avg_w
      }
      if (0 !=
          write_channel->write_line(paddle::string::join_strings(t, ' '))) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
                      "path:"
                   << channel_config.path << ", retry_num=" << retry_num;
        break;
      }
    }
    ++feasign_size;
    write_channel->close();
    if (err_no == -1) {
      ++retry_num;
      is_write_failed = true;
      LOG(ERROR) << "DownpourDenseTable save failed after write, retry it! "
                 << "path:" << channel_config.path
                 << ", retry_num=" << retry_num;
    }
    if (is_write_failed) {
      _afs_client.remove(channel_config.path);
    }
    if (retry_num >
        paddle::distributed::FLAGS_pslib_table_save_max_retry_dense) {
      LOG(ERROR) << "DownpourDenseTable save failed reach max limit!";
      exit(-1);
    }
  } while (is_write_failed);
  LOG(INFO) << "DownpourDenseTable save success, path:" << channel_config.path;
  return feasign_size;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/memory_dense_table.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <string>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/dense.h"
#include "paddle/fluid/distributed/ps/table/depends/initializers.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
class DenseOptimizer;

// In-memory dense parameter table for the parameter server.
// Parameters live in values_ (one float vector per named param); pushes are
// applied through a DenseOptimizer sharded over a small thread pool. In sync
// mode, pushes accumulate into pull_reservoir_ and are applied by Pour().
class MemoryDenseTable : public Table {
 public:
  MemoryDenseTable() {}
  virtual ~MemoryDenseTable() {}
  int32_t Initialize() override;
  int32_t InitializeShard() override { return 0; }
  // Build the named initializer (e.g. fill/gaussian) described by `attr`.
  void CreateInitializer(const std::string& attr, const std::string& name);
  int32_t InitializeValue();
  int32_t InitializeOptimizer();
  // Dispatch a pull/push request from a generic TableContext.
  int32_t Pull(TableContext& context) override;
  int32_t Push(TableContext& context) override;
  // Copy param values out into pull_values (num floats expected).
  int32_t PullDense(float* pull_values, size_t num);
  // Overwrite the parameter values directly (no optimizer).
  int32_t PushDenseParam(const float* values, size_t num);
  // Apply a gradient (sync: buffer in reservoir; async: apply immediately).
  int32_t PushDense(const float* values, size_t num);
  // Apply the averaged reservoir gradient and reset it.
  int32_t Pour() override;
  int32_t SetGlobalLR(float* lr) override;
  int32_t Load(const std::string& path, const std::string& param) override;
  int32_t Save(const std::string& path, const std::string& param) override;
  int32_t Flush() override { return 0; }
  int32_t Shrink(const std::string& param) override { return 0; }
  void Clear() override { return; }
  // Dense table has no per-shard storage to expose.
  void* GetShard(size_t shard_idx) override { return 0; }

 protected:
  // Sharded optimizer update over [0, param_dim_); see .cc for bucketing.
  int32_t _PushDense(const float* values, size_t num);

 private:
  // Number of single-threaded worker pools used to shard dense updates.
  const int task_pool_size_ = 10;
  // Sync-training mode: buffer pushes in pull_reservoir_ until Pour().
  bool sync = true;
  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  // Width of the trainable parameter (dimensions owned by this shard).
  int param_dim_ = 0;
  // Index of the main parameter inside values_.
  int param_idx_ = 0;
  std::shared_ptr<DenseOptimizer> optimizer_;
  // One float vector per configured param (w, moments, lr, ...).
  std::vector<std::vector<float>> values_;
  // Accumulates sync-mode gradients between Pour() calls.
  ReservoirValue<float> pull_reservoir_;
  std::unordered_map<std::string, Initializer*> initializers_;
  // Maps param name -> index into values_.
  std::unordered_map<std::string, int> names_index_;
  int total_dim_ = 0;
  int fixed_len_params_dim_ = 0;
  // used for save/load
  std::vector<int> param_col_ids_;
  // used for save/load
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
namespace
paddle
{
namespace
distributed
{
// Route a sparse pull request: a non-null geo_pull_keys means the trainer is
// asking for its accumulated geo deltas; otherwise it is a plain sparse pull.
int32_t MemorySparseGeoTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& pull_ctx = context.pull_context;
  if (pull_ctx.geo_pull_keys == nullptr) {
    return PullSparse(pull_ctx.values, pull_ctx.pull_value);
  }
  return PullGeoParam(
      context.trainer_id, pull_ctx.geo_pull_values, pull_ctx.geo_pull_keys);
}
// Route a sparse push request: is_param selects a direct parameter overwrite
// (PushSparseParam) versus a delta/gradient push (PushSparse).
int32_t MemorySparseGeoTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& push_ctx = context.push_context;
  return push_ctx.is_param
             ? PushSparseParam(push_ctx.keys, push_ctx.values, context.num)
             : PushSparse(push_ctx.keys, push_ctx.values, context.num);
}
// Overwrite parameter values for `num` keys. values is laid out as num rows
// of _dim floats (row i belongs to keys[i]). Keys are bucketed by
// key % shard_num, and each shard's single-threaded pool writes its bucket.
int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
                                              const float* values,
                                              size_t num) {
  VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
             "PushSparseParam "
          << num;
  auto shard_num = _task_pool_size;
  // offset_bucket[shard] holds indices into keys/values for that shard.
  std::vector<std::vector<uint64_t>> offset_bucket;
  offset_bucket.resize(shard_num);
  for (size_t x = 0; x < num; ++x) {
    auto y = keys[x] % shard_num;
    offset_bucket[y].push_back(x);
    if (x < 10) {
      VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x]
              << " shard: " << y;
    }
  }
  std::vector<std::future<int>> tasks(shard_num);
  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &keys, &offset_bucket, &values]() -> int {
          auto& local_shard = _local_shards[shard_id];
          auto& offsets = offset_bucket[shard_id];
          for (size_t i = 0; i < offsets.size(); ++i) {
            auto offset = offsets[i];
            auto id = keys[offset];
            // operator[] inserts the feature if absent; then copy the row.
            auto& feature_value = local_shard[id];
            feature_value.resize(_dim);
            std::copy_n(values + _dim * offset, _dim, feature_value.data());
            if (i < 10) {
              VLOG(5) << "MemorySparseGeoTable::PushSparseParam "
                         "PushSparseParam key "
                      << id << " value[0]: " << (values + _dim * offset)[0]
                      << " data: " << feature_value.data()[0]
                      << " value[-1]: " << (values + _dim * offset)[_dim - 1]
                      << " data: " << feature_value.data()[_dim - 1];
            }
          }
          return 0;
        });
  }
  // Join all shard writers before returning.
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
// Pull the geo-updated keys recorded for `trainer_id` since its last pull:
// drain the recorder into `ids`, then pull the current values for those ids
// into `values` (resized to ids->size() * _dim).
int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id,
                                           std::vector<float>* values,
                                           std::vector<uint64_t>* ids) {
  _geo_recorder->GetAndClear(trainer_id, ids);
  VLOG(5)
      << "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
      << trainer_id << " id_num: " << ids->size();
  // Frequencies are not tracked for geo pulls; use 1 for every key.
  std::vector<uint32_t> frequencies;
  frequencies.resize(ids->size(), 1);
  auto pull_value = PullSparseValue(ids->size(), _dim);
  pull_value.is_training_ = true;
  pull_value.feasigns_ = ids->data();
  pull_value.frequencies_ = frequencies.data();
  values->resize(ids->size() * _dim);
  PullSparse(values->data(), pull_value);
  return 0;
}
// Apply a sparse delta push: record the touched ids in the geo recorder so
// other trainers will receive them, then add the deltas into local shards.
int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys,
                                         const float* values,
                                         size_t num) {
  VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0]
          << " key_num: " << num;
  std::vector<uint64_t> ids(keys, keys + num);
  _geo_recorder->Update(ids);
  _PushSparse(keys, values, num);
  return 0;
}
// Set up recorder, feature dimension, worker pools, and shard storage.
int32_t MemorySparseGeoTable::Initialize() {
  // Lazily create the per-trainer geo recorder on first initialization.
  if (!_geo_recorder) {
    auto trainers = _config.common().trainer_num();
    _geo_recorder = std::make_shared<GeoRecorder>(trainers);
  }
  _dim = _config.common().dims()[0];
  // One single-threaded pool per shard keeps each shard's updates serialized.
  _shards_task_pool.resize(_task_pool_size);
  for (auto& pool : _shards_task_pool) {
    pool.reset(new ::ThreadPool(1));
  }
  _local_shards.reset(new shard_type[_task_pool_size]);
  return 0;
}
// hash different from MemorySparseTable
// Pull values for pull_value.numel_ keys into pull_values (row i of _dim
// floats belongs to feasign i). Keys are bucketed by key % _task_pool_size;
// a missing key is created zero-filled (and logged) before being copied out.
int32_t MemorySparseGeoTable::PullSparse(float* pull_values,
                                         const PullSparseValue& pull_value) {
  auto shard_num = _task_pool_size;
  std::vector<std::future<int>> tasks(shard_num);
  // task_keys[shard] holds (key, original offset) pairs for that shard.
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
    int shard_id = pull_value.feasigns_[i] % shard_num;
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &task_keys, pull_values]() -> int {
          auto& local_shard = _local_shards[shard_id];
          auto& keys = task_keys[shard_id];
          for (size_t i = 0; i < keys.size(); i++) {
            uint64_t key = keys[i].first;
            auto offset = keys[i].second;
            float* select_data = pull_values + _dim * offset;
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              // ++missed_keys;
              // Missing key: insert a zero-initialized value, then re-find
              // to obtain a valid iterator for the copy below.
              auto& feature_value = local_shard[key];
              feature_value.resize(_dim);
              memset(feature_value.data(), 0, sizeof(float) * _dim);
              VLOG(0) << "MemorySparseGeoTable PullSparse key not found!!! "
                      << key;
              itr = local_shard.find(key);
            }
            memcpy(select_data, itr.value().data(), _dim * sizeof(float));
            VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key
                    << " select_data[0] " << select_data[0]
                    << " value[0]: " << itr.value().data()[0];
          }
          return 0;
        });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
// Add `num` delta rows (values, _dim floats each) into the stored values for
// the matching keys, sharded by key % _task_pool_size. A missing key is
// created zero-filled before the add, so the push degenerates to an assign.
int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys,
                                          const float* values,
                                          size_t num) {
  auto shard_num = _task_pool_size;
  std::vector<std::future<int>> tasks(shard_num);
  // task_keys[shard] holds (key, original row index) pairs for that shard.
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = keys[i] % shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, values, &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
          auto& local_shard = _local_shards[shard_id];
          auto blas = GetBlas<float>();
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values + push_data_idx * _dim;
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              VLOG(0) << "sparse geo table push not found key!!! " << key;
              // Create a zero value so the VADD below acts as an assign.
              auto& feature_value = local_shard[key];
              feature_value.resize(_dim);
              memset(feature_value.data(), 0, sizeof(float) * _dim);
              itr = local_shard.find(key);
            }
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            VLOG(5) << "DEBUG MemorySparseGeoTable::_push_sparse before key: "
                    << key << " update_data[0] " << update_data[0]
                    << " value[0]: " << value_data[0];
            // value += update (element-wise, _dim floats).
            blas.VADD(_dim, update_data, value_data, value_data);
            VLOG(5) << "DEBUG MemorySparseGeoTable::_push_sparse after key: "
                    << key << " value[0]: " << value_data[0];
          }
          return 0;
        });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
// #include <pthread.h>
#include <stdint.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/distributed/ps/table/depends/geo_recorder.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
class GeoRecorder;

// Sparse table variant for GEO (asynchronous, delta-based) training.
// Values are FixedFeatureValue rows of _dim floats, sharded over
// _task_pool_size SparseTableShard buckets; _geo_recorder tracks which keys
// each trainer still needs to pull.
class MemorySparseGeoTable : public Table {
 public:
  typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
  MemorySparseGeoTable() { _geo_recorder = nullptr; }
  virtual ~MemorySparseGeoTable() {}

  int32_t Initialize() override;
  int32_t InitializeShard() override { return 0; }
  // Geo tables are not persisted; Load/Save are no-ops.
  int32_t Load(const std::string& path, const std::string& param) override {
    return 0;
  }
  int32_t Save(const std::string& path, const std::string& param) override {
    return 0;
  }
  int32_t Pull(TableContext& context) override;
  int32_t Push(TableContext& context) override;
  int32_t Flush() override { return 0; }
  int32_t Shrink(const std::string& param) override { return 0; }
  void Clear() override { return; }

  // Copy current values for pull_value's keys into `values`.
  int32_t PullSparse(float* values, const PullSparseValue& pull_value);
  // Overwrite values for `num` keys (rows of _dim floats).
  int32_t PushSparseParam(const uint64_t* keys, const float* values,
                          size_t num);
  // Drain keys recorded for `trainer_id` and pull their values.
  int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
                       std::vector<uint64_t>* keys);
  // Record the touched keys and add the deltas into the table.
  int32_t PushSparse(const uint64_t* keys, const float* values, size_t num);
  int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num);
  // int32_t _pull_sparse(float* pull_values, const PullSparseValue&
  // pull_value);

  void* GetShard(size_t shard_idx) override {
    return &_local_shards[shard_idx];
  }

 private:
  std::shared_ptr<GeoRecorder> _geo_recorder;
  // Shard count; one single-threaded pool per shard serializes its updates.
  const int _task_pool_size = 10;
  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  std::unique_ptr<shard_type[]> _local_shards;
  // Fix: _dim was left uninitialized; a pull/push before Initialize() would
  // read an indeterminate value. Default it to 0.
  int _dim = 0;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include <omp.h>
#include <sstream>
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/enforce.h"
// Runtime flags (gflags) controlling MemorySparseTable push/save behavior.
DEFINE_bool(pserver_print_missed_key_num_every_push,
            false,
            "pserver_print_missed_key_num_every_push");
// When true, a push to a missing key creates the value on the fly.
DEFINE_bool(pserver_create_value_when_push,
            true,
            "pserver create value when push");
DEFINE_bool(pserver_enable_create_feasign_randomly,
            false,
            "pserver_enable_create_feasign_randomly");
// Max retries for table save/load channels before aborting the process.
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");
namespace
paddle
{
namespace
distributed
{
// Set up the per-shard worker pools, cost profilers, and shard storage.
int32_t MemorySparseTable::Initialize() {
  // One single-threaded pool per shard keeps each shard's updates serialized.
  _shards_task_pool.resize(_task_pool_size);
  for (auto& pool : _shards_task_pool) {
    pool.reset(new ::ThreadPool(1));
  }
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
  InitializeValue();
  // Fix: log message typo "initalize" -> "initialize".
  VLOG(0) << "initialize MemorySparseTable succ";
  return 0;
}
// Compute how many of the table's global shards this node owns and allocate
// their local storage. Shards are distributed evenly (_avg_local_shard_num
// per node); the last node may own fewer (_real_local_shard_num).
int32_t MemorySparseTable::InitializeValue() {
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
  _real_local_shard_num = _avg_local_shard_num;
  // Trim the last node's share if the even split overshoots the total.
  if (static_cast<int>(_real_local_shard_num * (_shard_idx + 1)) >
      _sparse_table_shard_num) {
    _real_local_shard_num =
        _sparse_table_shard_num - _real_local_shard_num * _shard_idx;
    _real_local_shard_num = _real_local_shard_num < 0 ? 0 : _real_local_shard_num;
  }
  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num;
  _local_shards.reset(new shard_type[_real_local_shard_num]);
  return 0;
}
// Load this node's local shards from the remote/AFS files under `path`.
// Expects exactly one file per global shard; this node reads files
// [file_start_idx, file_start_idx + _real_local_shard_num). Each line is
// "<key> <accessor-serialized value>". `param` is the accessor load mode.
int32_t MemorySparseTable::Load(const std::string& path,
                                const std::string& param) {
  std::string table_path = TableDir(path);
  auto file_list = _afs_client.list(table_path);
  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
  }
  int load_param = atoi(param.c_str());
  size_t expect_shard_num = _sparse_table_shard_num;
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
  // Pre-size each value to the accessor's full float width before parsing.
  size_t feature_value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  // Load shards in parallel with OpenMP (cap at 15 threads).
  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
    channel_config.converter = _value_accesor->Converter(load_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(load_param).deconverter;
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    // Re-read the whole file until a clean pass (bounded retries).
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char* end = NULL;
      auto& shard = _local_shards[i];
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          // Line format: "<key> <value...>"; strtoul leaves `end` at the
          // separator, so ++end points at the value payload.
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
          auto& value = shard[key];
          value.resize(feature_value_size);
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
          value.resize(parse_size);
          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value "
                    << ii << ": " << value.data()[ii] << " local_shard: " << i;
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
            << file_list[file_start_idx + _real_local_shard_num - 1];
  return 0;
}
// Same as Load() but reads shard files from the local filesystem with
// std::ifstream instead of the AFS client. Line format and shard/file
// assignment are identical.
int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
                                       const std::string& param) {
  std::string table_path = TableDir(path);
  auto file_list = paddle::framework::localfs_list(table_path);
  size_t expect_shard_num = _sparse_table_shard_num;
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
  size_t feature_value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      std::ifstream file(file_list[file_start_idx + i]);
      char* end = NULL;
      auto& shard = _local_shards[i];
      try {
        while (std::getline(file, line_data) && line_data.size() > 1) {
          // "<key> <value...>": strtoul leaves `end` at the separator.
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
          auto& value = shard[key];
          value.resize(feature_value_size);
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
          value.resize(parse_size);
        }
        file.close();
        // NOTE(review): nothing on this local-FS path ever sets err_no, so
        // this branch looks unreachable (copied from the AFS variant) —
        // confirm before relying on its retry behavior.
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << file_list[file_start_idx + i] << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << file_list[file_start_idx + i]
                   << " , retry_num=" << retry_num;
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
            << file_list[file_start_idx + _real_local_shard_num - 1];
  return 0;
}
// Save every local shard to its own AFS file "part-%03d-%05d[.gz]" under
// `dirname`, one "<key> <serialized value>" line per saved feature. The
// accessor's Save() filters which features are written for this save mode.
int32_t MemorySparseTable::Save(const std::string& dirname,
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param = atoi(param.c_str());
  // checkpoint:0 xbox delta:1 xbox base:2
  std::string table_path = TableDir(dirname);
  // Remove any stale part files from a previous save by this shard index.
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
  // Save shards in parallel with OpenMP (cap at 20 threads).
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    // Compress only checkpoint-style saves (modes 0 and 3).
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d.gz",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    } else {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    }
    channel_config.converter = _value_accesor->Converter(save_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(save_param).deconverter;
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
    auto& shard = _local_shards[i];
    // Rewrite the whole shard file until a full pass succeeds.
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
      for (auto it = shard.begin(); it != shard.end(); ++it) {
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
              it.value().data(), it.value().size());
          if (0 !=
              write_channel->write_line(paddle::string::format_string(
                  "%lu %s", it.key(), format_value.c_str()))) {
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
          }
          ++feasign_size;
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
    // Let the accessor update per-feature statistics now that the save is
    // durable (e.g. delta-save bookkeeping).
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path;
  }
  // int32 may overflow need to change return value
  return 0;
}
// Save every local shard to a local file "part-<prefix>-%03d-%05d" under
// `dirname`, one "<key> <serialized value>" line per feature the accessor
// chooses to save for this mode.
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
                                       const std::string& param,
                                       const std::string& prefix) {
  int save_param = atoi(param.c_str());
  // checkpoint:0 xbox delta:1 xbox base:2
  std::string table_path = TableDir(dirname);
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
  // Save shards in parallel with OpenMP (cap at 20 threads).
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    // Fix: feasign_cnt was a single function-scope variable written and reset
    // concurrently by every OpenMP worker (data race, corrupted per-shard
    // counts in the log). Make it loop-local so each iteration owns its own
    // counter. The unused std::atomic feasign_size_all was removed as well.
    int feasign_cnt = 0;
    auto& shard = _local_shards[i];
    std::string file_name =
        paddle::string::format_string("%s/part-%s-%03d-%05d",
                                      table_path.c_str(),
                                      prefix.c_str(),
                                      _shard_idx,
                                      file_start_idx + i);
    std::ofstream os;
    os.open(file_name);
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      if (_value_accesor->Save(it.value().data(), save_param)) {
        std::string format_value = _value_accesor->ParseToString(
            it.value().data(), it.value().size());
        std::string out_line = paddle::string::format_string(
            "%lu %s\n", it.key(), format_value.c_str());
        // VLOG(2) << out_line.c_str();
        os.write(out_line.c_str(), sizeof(char) * out_line.size());
        ++feasign_cnt;
      }
    }
    os.close();
    LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
              << "feasign_cnt: " << feasign_cnt;
  }
  return 0;
}
// Total number of feasigns currently stored across all shards owned by
// this server. Purely a read-side aggregation; no locking is performed.
int64_t MemorySparseTable::LocalSize() {
  int64_t total = 0;
  for (int shard_idx = 0; shard_idx < _real_local_shard_num; ++shard_idx) {
    total += _local_shards[shard_idx].size();
  }
  return total;
}
int64_t
MemorySparseTable
::
LocalMFSize
()
{
std
::
vector
<
int64_t
>
size_arr
(
_real_local_shard_num
,
0
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
int64_t
ret_size
=
0
;
for
(
int
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_shards_task_pool
.
size
()]
->
enqueue
(
[
this
,
shard_id
,
&
size_arr
]()
->
int
{
auto
&
local_shard
=
_local_shards
[
shard_id
];
for
(
auto
it
=
local_shard
.
begin
();
it
!=
local_shard
.
end
();
++
it
)
{
if
(
_value_accesor
->
HasMF
(
it
.
value
().
size
()))
{
size_arr
[
shard_id
]
+=
1
;
}
}
return
0
;
});
}
for
(
int
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
tasks
[
i
].
wait
();
}
for
(
auto
x
:
size_arr
)
{
ret_size
+=
x
;
}
return
ret_size
;
}
// Returns {total feasign count, feasigns that already have the mf part}
// for the shards owned by this server.
std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  return std::make_pair(LocalSize(), LocalMFSize());
}
// Dispatches a pull request to the pointer-based or value-based path
// depending on context.use_ptr. Only sparse contexts are accepted.
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    // Caller wants raw pointers into the stored feature values.
    return PullSparsePtr(context.pull_context.ptr_values,
                         context.pull_context.keys,
                         context.num);
  }
  // Caller wants selected values copied into a flat float buffer.
  return PullSparse(context.pull_context.values,
                    context.pull_context.pull_value);
}
// Dispatches a push request to the pointer-based or flat-buffer overload of
// PushSparse depending on context.use_ptr. Only sparse contexts are accepted.
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    return PushSparse(context.push_context.keys,
                      context.push_context.ptr_values,
                      context.num);
  }
  return PushSparse(context.push_context.keys,
                    context.push_context.values,
                    context.num);
}
// Pulls the "select" view of each requested feasign into pull_values.
// Keys are bucketed per local shard and each shard is processed on its
// pinned task-pool thread, so access to a given shard is serialized.
// Missing keys are either returned as zeros (when
// FLAGS_pserver_create_value_when_push is set, creation is deferred to
// push time) or created on the spot via the accessor.
int32_t MemorySparseTable::PullSparse(float* pull_values,
                                      const PullSparseValue& pull_value) {
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);

  // All sizes below are in floats, converted from the accessor's byte sizes.
  const size_t value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t select_value_size =
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
  // std::atomic<uint32_t> missed_keys{0};

  // task_keys[shard] holds {feasign, original request index} pairs so the
  // worker can scatter results back to the caller's ordering.
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &task_keys, value_size, pull_values,
             mf_value_size, select_value_size]() -> int {
              auto& local_shard = _local_shards[shard_id];
              // Scratch buffer for one full value (VLA sized per accessor).
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;
              auto& keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                // Size of the value without the mf (embedx) extension.
                size_t data_size = value_size - mf_value_size;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  if (FLAGS_pserver_create_value_when_push) {
                    // Don't materialize on pull: hand back zeros; the value
                    // will be created when it is first pushed.
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
                    // Create and store a fresh (mf-less) value now.
                    auto& feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float* data_ptr = feature_value.data();
                    _value_accesor->Create(&data_buffer_ptr, 1);
                    memcpy(data_ptr, data_buffer_ptr,
                           data_size * sizeof(float));
                  }
                } else {
                  // Existing value: copy whatever length it currently has.
                  data_size = itr.value().size();
                  memcpy(data_buffer_ptr, itr.value().data(),
                         data_size * sizeof(float));
                }
                // Zero-fill the (possibly absent) mf tail of the buffer.
                for (size_t mf_idx = data_size; mf_idx < value_size;
                     ++mf_idx) {
                  data_buffer[mf_idx] = 0.0;
                }
                // Scatter the selected fields to the caller's slot.
                auto offset = keys[i].second;
                float* select_data = pull_values + select_value_size * offset;
                _value_accesor->Select(&select_data,
                                       (const float**)&data_buffer_ptr, 1);
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
// Pointer-variant pull: for each key, stores a pointer to the in-table
// FixedFeatureValue into pull_values (creating a fresh mf-less value for
// missing keys). Keys are bucketed per local shard; each shard runs on its
// pinned task-pool thread, serializing access to that shard.
//
// NOTE: the returned pointers alias live table storage; they remain valid
// only as long as the entry is not erased or resized.
//
// Fix: replaced `NULL` with `nullptr` and the C-style `(char*)` cast with
// `reinterpret_cast`, per modern C++ conventions. No behavior change.
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
                                         const uint64_t* keys,
                                         size_t num) {
  CostTimer timer("pscore_sparse_select_all");
  // Sizes in floats, from the accessor's byte sizes.
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  // {feasign, original request index} per shard, to restore request order.
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &task_keys, pull_values, value_size,
             mf_value_size]() -> int {
              auto& keys = task_keys[shard_id];
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;
              for (size_t i = 0; i < keys.size(); ++i) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                // Length of a value without the mf (embedx) extension.
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue* ret = nullptr;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  // Missing key: create and store a fresh mf-less value,
                  // then expose a pointer to the stored copy.
                  auto& feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float* data_ptr = feature_value.data();
                  _value_accesor->Create(&data_buffer_ptr, 1);
                  memcpy(data_ptr, data_buffer_ptr,
                         data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = reinterpret_cast<char*>(ret);
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
// Applies gradient updates (flat buffer layout: one update_value_col-wide
// row per key) to the stored feature values. Keys are bucketed per local
// shard and processed on the shard's pinned task-pool thread.
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float* values,
                                      size_t num) {
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  // {feasign, row index into `values`} per shard.
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }

  // Column counts in floats, from the accessor's byte sizes.
  const size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col = _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);

  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
          auto& local_shard = _local_shards[shard_id];
          // Scratch buffer for one full value.
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
                values + push_data_idx * update_value_col;
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              // Optionally admit new feasigns probabilistically; skip the
              // update entirely when admission is denied.
              if (FLAGS_pserver_enable_create_feasign_randomly &&
                  !_value_accesor->CreateValue(1, update_data)) {
                continue;
              }
              // Create a fresh mf-less value, then re-find to get the
              // iterator used below.
              auto value_size = value_col - mf_value_col;
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(), data_buffer_ptr,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
            if (value_size == value_col) {
              // Already extended to the maximum size: update in place.
              _value_accesor->Update(&value_data, &update_data, 1);
            } else {
              // Copy into the scratch buffer, update there, then write back;
              // the mf tail is discarded on write-back unless the accessor
              // decides the value should now grow its mf part.
              memcpy(data_buffer_ptr, value_data,
                     value_size * sizeof(float));
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
                feature_value.resize(value_col);
                value_data = feature_value.data();
                _value_accesor->Create(&value_data, 1);
              }
              memcpy(value_data, data_buffer_ptr,
                     value_size * sizeof(float));
            }
          }
          return 0;
        });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
// Pointer-variant push: identical to the flat-buffer overload except each
// key's gradient row is addressed through values[row] instead of a stride
// into one contiguous buffer.
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float** values,
                                      size_t num) {
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  // {feasign, index into `values`} per shard.
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }

  // Column counts in floats, from the accessor's byte sizes.
  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col = _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);

  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
          auto& local_shard = _local_shards[shard_id];
          // Scratch buffer for one full value.
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              // Optional probabilistic admission of brand-new feasigns.
              if (FLAGS_pserver_enable_create_feasign_randomly &&
                  !_value_accesor->CreateValue(1, update_data)) {
                continue;
              }
              // Create a fresh mf-less value, then re-find the iterator.
              auto value_size = value_col - mf_value_col;
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(), data_buffer_ptr,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
            if (value_size == value_col) {
              // Already extended to the maximum size: update in place.
              _value_accesor->Update(&value_data, &update_data, 1);
            } else {
              // Copy into the scratch buffer, update there, then write back;
              // the unneeded mf tail is discarded on write-back unless the
              // accessor decides to extend the stored value.
              memcpy(data_buffer_ptr, value_data,
                     value_size * sizeof(float));
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
                feature_value.resize(value_col);
                value_data = feature_value.data();
                _value_accesor->Create(&value_data, 1);
              }
              memcpy(value_data, data_buffer_ptr,
                     value_size * sizeof(float));
            }
          }
          return 0;
        });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}
// No-op: values live purely in memory and every push is applied
// synchronously, so there is nothing buffered to flush.
int32_t MemorySparseTable::Flush() { return 0; }
// Walks every local shard and erases the feasigns the accessor marks for
// removal (Shrink() also applies time decay to each value as a side effect).
// Runs single-threaded for now.
int32_t MemorySparseTable::Shrink(const std::string& param) {
  VLOG(0) << "MemorySparseTable::Shrink";
  // TODO(zhaocaibei123): implement with multi-thread
  for (int sid = 0; sid < _real_local_shard_num; ++sid) {
    auto& cur_shard = _local_shards[sid];
    auto iter = cur_shard.begin();
    while (iter != cur_shard.end()) {
      if (!_value_accesor->Shrink(iter.value().data())) {
        ++iter;
      } else {
        // erase() returns the next valid iterator.
        iter = cur_shard.erase(iter);
      }
    }
  }
  return 0;
}
// NOTE(review): Clear() is not implemented yet — it only logs and leaves
// every shard's contents intact. Callers must not rely on it releasing
// memory or emptying the table.
void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; }
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/memory_sparse_table.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/string/string_helper.h"
#define PSERVER_SAVE_SUFFIX ".shard"
namespace
paddle
{
namespace
distributed
{
// In-memory sparse table: feasign -> FixedFeatureValue, partitioned into
// local shards. Each shard is pinned to a task-pool thread, which serializes
// access to that shard's data.
class MemorySparseTable : public Table {
 public:
  typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
  MemorySparseTable() {}
  virtual ~MemorySparseTable() {}

  // unused method end

  // Number of shards each server holds: ceil(shard_num / server_num)
  // (exact division when shard_num is a multiple of server_num).
  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
    if (shard_num % server_num == 0) {
      return shard_num / server_num;
    }
    size_t local_shard_num = shard_num / server_num + 1;
    return local_shard_num;
  }

  // Maps a key to the server index that owns its global shard.
  static size_t get_sparse_shard(uint32_t shard_num,
                                 uint32_t server_num,
                                 uint64_t key) {
    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
  }

  // Sparse pull/push entry points (dispatch on context.use_ptr).
  int32_t Pull(TableContext& context) override;
  int32_t Push(TableContext& context) override;

  int32_t Initialize() override;
  int32_t InitializeShard() override { return 0; }
  int32_t InitializeValue();

  // Load/save via the configured (possibly remote) filesystem client.
  virtual int32_t Load(const std::string& path,
                       const std::string& param) override;
  virtual int32_t Save(const std::string& path,
                       const std::string& param) override;

  // Load/save against the local filesystem only.
  int32_t LoadLocalFS(const std::string& path, const std::string& param);
  int32_t SaveLocalFS(const std::string& path,
                      const std::string& param,
                      const std::string& prefix);

  // Feasign count / count of values that already have the mf extension.
  int64_t LocalSize();
  int64_t LocalMFSize();
  std::pair<int64_t, int64_t> PrintTableStat() override;

  int32_t PullSparse(float* values, const PullSparseValue& pull_value);
  int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num);
  int32_t PushSparse(const uint64_t* keys, const float* values, size_t num);
  int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);

  int32_t Flush() override;
  virtual int32_t Shrink(const std::string& param) override;
  void Clear() override;

  // Raw access to one shard (used by framework code that knows shard_type).
  void* GetShard(size_t shard_idx) override {
    return &_local_shards[shard_idx];
  }

 protected:
  const int _task_pool_size = 24;        // threads in the shard task pool
  int _avg_local_shard_num;              // shards per server (rounded up)
  int _real_local_shard_num;             // shards actually owned locally
  int _sparse_table_shard_num;           // global shard count
  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  std::unique_ptr<shard_type[]> _local_shards;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/sparse_accessor.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
// Builds the embed/embedx SGD rules from the accessor config, wires the
// resulting optimizer state dims into the value layout descriptor, and
// finishes by populating the AccessorInfo size table.
int SparseAccessor::Initialize() {
  // One-dimensional SGD rule for the scalar embedding weight.
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);

  // embedx_dim-wide SGD rule for the mf (embedx) part.
  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());

  // The layout depends on the optimizer-state dims chosen by the rules.
  sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
  sparse_feature_value.embedx_dim = _config.embedx_dim();
  sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();

  InitAccessorInfo();
  return 0;
}
// Fills _accessor_info with the byte/element sizes of the value, select
// (pull), update (push), and mf regions, derived from the value layout.
// Fix: removed a stray empty statement (double semicolon) after the
// select_size assignment. No behavior change.
void SparseAccessor::InitAccessorInfo() {
  _accessor_info.dim = sparse_feature_value.Dim();
  _accessor_info.size = sparse_feature_value.Size();

  auto embedx_dim = _config.embedx_dim();
  // Pull row: embed_w + embedx_w.
  _accessor_info.select_dim = 1 + embedx_dim;
  _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
  // Push row: slot + show + click + embed_g + embedx_g.
  _accessor_info.update_dim = 4 + embedx_dim;
  _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
  // The mf (extension) region: embedx weights plus their optimizer state.
  _accessor_info.mf_size =
      (embedx_dim + sparse_feature_value.embedx_sgd_dim) * sizeof(float);
}
// Decides whether a stored value should be deleted. NOTE: this also
// mutates the value — show/click are decayed by _show_click_decay_rate
// before the deletion test, so calling Shrink has a side effect even when
// it returns false.
bool SparseAccessor::Shrink(float* value) {
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();

  // time_decay first
  sparse_feature_value.Show(value) *= _show_click_decay_rate;
  sparse_feature_value.Click(value) *= _show_click_decay_rate;

  // shrink after
  auto score = ShowClickScore(sparse_feature_value.Show(value),
                              sparse_feature_value.Click(value));
  auto unseen_days = sparse_feature_value.UnseenDays(value);
  // Delete when the decayed score is too low or the feasign has gone
  // unseen for too long.
  if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
    return true;
  }
  return false;
}
// Decides whether a value should be written out for the given save mode
// (param: 0 = save all, 1 = xbox delta, 2 = xbox base, 3 = already decayed
// in shrink, 5 = revert batch model). For mode 2 the delta threshold is
// ignored and a successful save resets the value's delta score.
bool SparseAccessor::Save(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    // xbox base saves everything above the base threshold regardless of delta.
    delta_threshold = 0;
  }
  switch (param) {
    // save all
    case 0: {
      return true;
    }
    // save xbox delta
    case 1:
    // save xbox base
    case 2: {
      if (ShowClickScore(sparse_feature_value.Show(value),
                         sparse_feature_value.Click(value)) >= base_threshold &&
          sparse_feature_value.DeltaScore(value) >= delta_threshold &&
          sparse_feature_value.UnseenDays(value) <= delta_keep_days) {
        // do this after save, because it must not be modified when retry
        if (param == 2) {
          sparse_feature_value.DeltaScore(value) = 0;
        }
        return true;
      } else {
        return false;
      }
    }
    // already decayed in shrink
    case 3: {
      // do this after save, because it must not be modified when retry
      // sparse_feature_value.UnseenDays(value)++;
      return true;
    }
    // save revert batch_model
    case 5: {
      return true;
    }
    default:
      return true;
  }
}
// Post-save bookkeeping, kept separate from Save() so a failed/retried save
// never mutates values: mode 1 resets the delta score of every value that
// qualified for the delta save; mode 3 bumps each value's unseen-day count.
void SparseAccessor::UpdateStatAfterSave(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    case 1: {
      // Same qualification test Save() used for the delta save.
      if (ShowClickScore(sparse_feature_value.Show(value),
                         sparse_feature_value.Click(value)) >= base_threshold &&
          sparse_feature_value.DeltaScore(value) >= delta_threshold &&
          sparse_feature_value.UnseenDays(value) <= delta_keep_days) {
        sparse_feature_value.DeltaScore(value) = 0;
      }
    }
      return;
    case 3: {
      sparse_feature_value.UnseenDays(value)++;
    }
      return;
    default:
      return;
  }
}
// Initializes `num` freshly allocated values: zeroes the stat fields, marks
// the slot as unset (-1), and lets the SGD rules initialize the embedding
// weights and their optimizer state. Memory must be pre-allocated by the
// caller.
int32_t SparseAccessor::Create(float** values, size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
    value[sparse_feature_value.UnseenDaysIndex()] = 0;
    value[sparse_feature_value.DeltaScoreIndex()] = 0;
    value[sparse_feature_value.ShowIndex()] = 0;
    value[sparse_feature_value.ClickIndex()] = 0;
    value[sparse_feature_value.SlotIndex()] = -1;
    // Scalar embedding weight + its optimizer state.
    _embed_sgd_rule->InitValue(
        value + sparse_feature_value.EmbedWIndex(),
        value + sparse_feature_value.EmbedG2SumIndex());
    // embedx part; `false` — assumes this skips random init here
    // (TODO confirm against SparseValueSGDRule::InitValue).
    _embedx_sgd_rule->InitValue(
        value + sparse_feature_value.EmbedxWIndex(),
        value + sparse_feature_value.EmbedxG2SumIndex(), false);
  }
  return 0;
}
// Whether a value's weighted show/click score has crossed the embedx
// threshold, i.e. the mf (embedx) region should now be materialized.
bool SparseAccessor::NeedExtendMF(float* value) {
  const float show_val = value[sparse_feature_value.ShowIndex()];
  const float click_val = value[sparse_feature_value.ClickIndex()];
  const auto& ctr_param = _config.ctr_accessor_param();
  const float weighted = (show_val - click_val) * ctr_param.nonclk_coeff() +
                         click_val * ctr_param.click_coeff();
  return weighted >= _config.embedx_threshold();
}
// A stored value contains the mf part iff its length extends past the
// start of the embedx optimizer-state region.
bool SparseAccessor::HasMF(int size) {
  const int mf_state_begin = sparse_feature_value.EmbedxG2SumIndex();
  return size > mf_state_begin;
}
// Converts stored SparseFeatureValues into pull rows (SparsePullValue):
// copies embed_w plus the embedx weight vector; stats and optimizer state
// are not exposed to the puller.
int32_t SparseAccessor::Select(float** select_values,
                               const float** values,
                               size_t num) {
  const auto embedx_width = _config.embedx_dim();
  for (size_t idx = 0; idx < num; ++idx) {
    float* dst = select_values[idx];
    const float* src = values[idx];
    dst[SparsePullValue::EmbedWIndex()] =
        src[sparse_feature_value.EmbedWIndex()];
    memcpy(dst + SparsePullValue::EmbedxWIndex(),
           src + sparse_feature_value.EmbedxWIndex(),
           embedx_width * sizeof(float));
  }
  return 0;
}
// Accumulates one batch of push rows into another (SparsePushValue layout,
// first dim: item, second dim: field). Every field is summed except the
// slot id, which is an identifier rather than a gradient.
int32_t SparseAccessor::Merge(float** update_values,
                              const float** other_update_values,
                              size_t num) {
  const size_t field_count = SparsePushValue::Dim(_config.embedx_dim());
  const int slot_pos = SparsePushValue::SlotIndex();
  for (size_t idx = 0; idx < num; ++idx) {
    float* dst = update_values[idx];
    const float* src = other_update_values[idx];
    for (size_t f = 0; f < field_count; ++f) {
      if (static_cast<int>(f) == slot_pos) {
        continue;  // slot is not additive
      }
      dst[f] += src[f];
    }
  }
  return 0;
}
// Applies push rows (SparsePushValue layout) to stored feature values
// (first dim: item, second dim: field): accumulates show/click/delta-score
// stats, overwrites the slot, resets unseen-days, and delegates the
// embedding weight updates to the SGD rules.
int32_t SparseAccessor::Update(float** update_values,
                              const float** push_values,
                              size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
    float push_show = push_value[SparsePushValue::ShowIndex()];
    float push_click = push_value[SparsePushValue::ClickIndex()];
    float slot = push_value[SparsePushValue::SlotIndex()];
    // Accumulate raw stats.
    update_value[sparse_feature_value.ShowIndex()] += push_show;
    update_value[sparse_feature_value.ClickIndex()] += push_click;
    update_value[sparse_feature_value.SlotIndex()] = slot;
    // Delta score grows by the weighted show/click contribution.
    update_value[sparse_feature_value.DeltaScoreIndex()] +=
        (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
        push_click * _config.ctr_accessor_param().click_coeff();
    // The feasign was just seen.
    update_value[sparse_feature_value.UnseenDaysIndex()] = 0;
    // Scalar embedding update (weight + optimizer state).
    _embed_sgd_rule->UpdateValue(
        update_value + sparse_feature_value.EmbedWIndex(),
        update_value + sparse_feature_value.EmbedG2SumIndex(),
        push_value + SparsePushValue::EmbedGIndex());
    // embedx vector update (weights + optimizer state).
    _embedx_sgd_rule->UpdateValue(
        update_value + sparse_feature_value.EmbedxWIndex(),
        update_value + sparse_feature_value.EmbedxG2SumIndex(),
        push_value + SparsePushValue::EmbedxGIndex());
  }
  return 0;
}
// Admission test for a feasign at the given stage (0 = pull, 1 = push).
// Pulls always create; pushes are admitted probabilistically in proportion
// to the pushed show/click score (always for score >= 1, never for <= 0).
bool SparseAccessor::CreateValue(int stage, const float* value) {
  switch (stage) {
    case 0:
      // pull: always create.
      return true;
    case 1: {
      // push: probabilistic admission by show/click score.
      float* mutable_value = const_cast<float*>(value);
      const auto score = ShowClickScore(SparsePushValue::Show(mutable_value),
                                        SparsePushValue::Click(mutable_value));
      if (score <= 0) {
        return false;
      }
      if (score >= 1) {
        return true;
      }
      return local_uniform_real_distribution<float>()(local_random_engine()) <
             score;
    }
    default:
      return true;
  }
}
// Weighted importance score: non-clicked shows and clicks contribute with
// their configured coefficients.
float SparseAccessor::ShowClickScore(float show, float click) {
  const auto& ctr_param = _config.ctr_accessor_param();
  const float non_click = show - click;
  return non_click * ctr_param.nonclk_coeff() + click * ctr_param.click_coeff();
}
// Serializes a value to the space-separated text form used by save files:
// the six leading stat fields, the embed optimizer state, and — only when
// the value's score passes the embedx threshold AND the stored length
// (passed as `param`) extends past the embedx start — the embedx tail.
std::string SparseAccessor::ParseToString(const float* v, int param) {
  // thread_local stream reused across calls to avoid re-allocation.
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  // slot, unseen_days, delta_score, show, click, embed_w.
  os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4]
     << " " << v[5];
  // embed optimizer state (embed_g2sum .. embedx_w).
  for (int i = sparse_feature_value.EmbedG2SumIndex();
       i < sparse_feature_value.EmbedxWIndex();
       i++) {
    os << " " << v[i];
  }
  auto show = sparse_feature_value.Show(const_cast<float*>(v));
  auto click = sparse_feature_value.Click(const_cast<float*>(v));
  auto score = ShowClickScore(show, click);
  // `param` here is the value's element count, not a save mode: comparing
  // it to EmbedxWIndex() checks whether the mf tail is actually present.
  if (score >= _config.embedx_threshold() &&
      param > sparse_feature_value.EmbedxWIndex()) {
    for (auto i = sparse_feature_value.EmbedxWIndex();
         i < sparse_feature_value.Dim();
         ++i) {
      os << " " << v[i];
    }
  }
  return os.str();
}
// Parses a space-separated save line back into a value buffer. The embedx
// region is first re-initialized via the SGD rule so that lines saved
// without an mf tail still yield a fully initialized value; str_to_float
// then overwrites as many fields as the line actually contains.
// Returns the number of floats parsed (must cover the six stat fields).
int SparseAccessor::ParseFromString(const std::string& str, float* value) {
  _embedx_sgd_rule->InitValue(
      value + sparse_feature_value.EmbedxWIndex(),
      value + sparse_feature_value.EmbedxG2SumIndex());

  auto ret = paddle::string::str_to_float(str.data(), value);
  CHECK(ret >= 6) << "expect more than 6 real:" << ret;
  return ret;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/sparse_accessor.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace
paddle
{
namespace
distributed
{
// no show click, for word2vec(DownpourSparseValueAccessor)
class
SparseAccessor
:
public
ValueAccessor
{
public:
  // Layout descriptor for a stored feature value. Field order:
  /*
  float slot;
  float unseen_days;
  float delta_score;
  float show;
  float click;
  float embed_w;
  std::vector<float> embed_g2sum;
  std::vector<float> embedx_w;
  std::vector<float> embedx_g2sum;
  */
  // The *_sgd_dim / embedx_dim members are set at Initialize() time from
  // the configured SGD rules, so all indices below are runtime values.
  struct SparseFeatureValue {
    // Total number of floats in a fully extended value.
    int Dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
    // Per-element byte size (uniform float layout).
    int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
    // Total byte size of a fully extended value.
    int Size() { return Dim() * sizeof(float); }
    // Field offsets, each defined relative to the previous field.
    int SlotIndex() { return 0; }
    int UnseenDaysIndex() { return SlotIndex() + 1; }
    int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
    int ShowIndex() { return DeltaScoreIndex() + 1; }
    int ClickIndex() { return ShowIndex() + 1; }
    int EmbedWIndex() { return ClickIndex() + 1; }
    int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
    int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
    int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; }
    // Typed accessors into a raw value buffer.
    float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
    float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
    float& Show(float* val) { return val[ShowIndex()]; }
    float& Click(float* val) { return val[ClickIndex()]; }
    float& Slot(float* val) { return val[SlotIndex()]; }
    float& EmbedW(float* val) { return val[EmbedWIndex()]; }
    float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
    float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
    float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }

    int embed_sgd_dim;    // optimizer-state floats for embed_w
    int embedx_dim;       // embedx weight vector length
    int embedx_sgd_dim;   // optimizer-state floats for embedx
  };
struct
SparsePushValue
{
/*
float slot;
float show;
float click;
float embed_g;
std::vector<float> embedx_g;
*/
static
int
Dim
(
int
embedx_dim
)
{
return
4
+
embedx_dim
;
}
static
int
DimSize
(
int
dim
,
int
embedx_dim
)
{
return
sizeof
(
float
);
}
static
int
Size
(
int
embedx_dim
)
{
return
Dim
(
embedx_dim
)
*
sizeof
(
float
);
}
static
int
SlotIndex
()
{
return
0
;
}
static
int
ShowIndex
()
{
return
SparsePushValue
::
SlotIndex
()
+
1
;
}
static
int
ClickIndex
()
{
return
SparsePushValue
::
ShowIndex
()
+
1
;
}
static
int
EmbedGIndex
()
{
return
SparsePushValue
::
ClickIndex
()
+
1
;
}
static
int
EmbedxGIndex
()
{
return
SparsePushValue
::
EmbedGIndex
()
+
1
;
}
static
float
&
Slot
(
float
*
val
)
{
return
val
[
SparsePushValue
::
SlotIndex
()];
}
static
float
&
Show
(
float
*
val
)
{
return
val
[
SparsePushValue
::
ShowIndex
()];
}
static
float
&
Click
(
float
*
val
)
{
return
val
[
SparsePushValue
::
ClickIndex
()];
}
static
float
&
EmbedG
(
float
*
val
)
{
return
val
[
SparsePushValue
::
EmbedGIndex
()];
}
static
float
*
EmbedxG
(
float
*
val
)
{
return
val
+
SparsePushValue
::
EmbedxGIndex
();
}
};
  // Layout of one pull (select) row. Field order:
  /*
  float embed_w;
  std::vector<float> embedx_w;
  */
  struct SparsePullValue {
    static int Dim(int embedx_dim) { return 1 + embedx_dim; }
    static int DimSize(size_t dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int EmbedWIndex() { return 0; }
    static int EmbedxWIndex() { return 1; }
    static float& EmbedW(float* val) {
      return val[SparsePullValue::EmbedWIndex()];
    }
    // Pointer to the start of the embedx weight vector.
    static float* EmbedxW(float* val) {
      return val + SparsePullValue::EmbedxWIndex();
    }
  };
  // Default-constructible; all real setup happens in Initialize().
  SparseAccessor() {}
  virtual ~SparseAccessor() {}
virtual
int
Initialize
();
// 初始化AccessorInfo
virtual
void
InitAccessorInfo
();
// 判断该value是否进行shrink
virtual
bool
Shrink
(
float
*
value
);
// 判断该value是否保存到ssd
// virtual bool save_ssd(float* value);
virtual
bool
NeedExtendMF
(
float
*
value
);
virtual
bool
HasMF
(
int
size
);
// 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model
// param = 0, save all feature
// param = 1, save delta feature
// param = 2, save xbox base feature
bool
Save
(
float
*
value
,
int
param
)
override
;
bool
SaveCache
(
float
*
value
,
int
param
,
double
global_cache_threshold
)
{
return
false
;
}
bool
SaveSSD
(
float
*
value
)
{
return
false
;
}
// update delta_score and unseen_days after save
void
UpdateStatAfterSave
(
float
*
value
,
int
param
)
override
;
// keys不存在时,为values生成随机值
// 要求value的内存由外部调用者分配完毕
virtual
int32_t
Create
(
float
**
value
,
size_t
num
);
// 从values中选取到select_values中
virtual
int32_t
Select
(
float
**
select_values
,
const
float
**
values
,
size_t
num
);
// 将update_values聚合到一起
virtual
int32_t
Merge
(
float
**
update_values
,
const
float
**
other_update_values
,
size_t
num
);
// 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中
virtual
int32_t
Update
(
float
**
values
,
const
float
**
update_values
,
size_t
num
);
std
::
string
ParseToString
(
const
float
*
value
,
int
param
)
override
;
int32_t
ParseFromString
(
const
std
::
string
&
str
,
float
*
v
)
override
;
virtual
bool
CreateValue
(
int
type
,
const
float
*
value
);
// 这个接口目前只用来取show
float
GetField
(
float
*
value
,
const
std
::
string
&
name
)
override
{
// CHECK(name == "show");
if
(
name
==
"show"
)
{
return
sparse_feature_value
.
Show
(
value
);
}
return
0.0
;
}
private:
// float ShowClickScore(float show, float click);
// SparseValueSGDRule* _embed_sgd_rule;
// SparseValueSGDRule* _embedx_sgd_rule;
// SparseFeatureValue sparse_feature_value;
float
_show_click_decay_rate
;
int32_t
_ssd_unseenday_threshold
;
public:
// TODO(zhaocaibei123): it should be private, but we make it public
// for unit test
SparseFeatureValue
sparse_feature_value
;
float
ShowClickScore
(
float
show
,
float
click
);
SparseValueSGDRule
*
_embed_sgd_rule
;
SparseValueSGDRule
*
_embedx_sgd_rule
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
// NOTE(review): presumably gates scaling gradients by the show count during
// sparse updates — confirm at the call sites of this flag.
DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace
paddle
{
namespace
distributed
{
// Reads the naive-SGD section of `param`: learning rate, init range and
// the optional [min, max] weight bounds. `emb_dim` is the embedding width
// this rule will update.
void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
                                    size_t emb_dim) {
  _embedding_dim = emb_dim;
  const auto& cfg = param.naive();
  learning_rate_ = cfg.learning_rate();
  _initial_range = cfg.initial_range();
  const int num_bounds = cfg.weight_bounds_size();
  if (num_bounds > 0) {
    CHECK(num_bounds >= 2) << "invalid repeated size for weight_bounds:"
                           << num_bounds;
    _min_bound = cfg.weight_bounds(0);
    _max_bound = cfg.weight_bounds(1);
  } else {
    // No explicit bounds configured: clamp only at the float range itself.
    _min_bound = -std::numeric_limits<float>::max();
    _max_bound = std::numeric_limits<float>::max();
  }
}
// Plain SGD step: w[i] -= lr * g[i], then clamp into the configured bounds.
// `sgd` is unused (this rule keeps no optimizer state).
// NOTE(review): unlike the AdaGrad/Adam rules, the gradient is NOT divided
// by `scale` here — confirm this asymmetry is intended.
void SparseNaiveSGDRule::UpdateValueWork(float* w,
                                         float* sgd,
                                         const float* push_value,
                                         float scale) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    w[i] -= learning_rate_ * push_value[i];
    BoundValue(w[i]);
  }
}
// Fills `value` with zeros (zero_init) or uniform random numbers in
// [-_initial_range, _initial_range]. `sgd` is untouched because naive SGD
// keeps no optimizer state; random values are clamped, zeros are not.
void SparseNaiveSGDRule::InitValueWork(float* value,
                                       float* sgd,
                                       bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
      value[i] = 0;
    } else {
      value[i] = (local_uniform_real_distribution<float>()(
                      local_random_engine()) *
                      2 -
                  1) *
                 _initial_range;
      BoundValue(value[i]);
    }
  }
}
// Reads the AdaGrad section of `param`: learning rate, the initial g2sum
// smoothing constant, init range and the optional [min, max] weight bounds.
void SparseAdaGradSGDRule::LoadConfig(
    const SparseCommonSGDRuleParameter& param, size_t emb_dim) {
  _embedding_dim = emb_dim;
  const auto& cfg = param.adagrad();
  learning_rate_ = cfg.learning_rate();
  _initial_g2sum = cfg.initial_g2sum();
  _initial_range = cfg.initial_range();
  const int num_bounds = cfg.weight_bounds_size();
  if (num_bounds > 0) {
    CHECK(num_bounds >= 2) << "invalid repeated size for weight_bounds:"
                           << num_bounds;
    _min_bound = cfg.weight_bounds(0);
    _max_bound = cfg.weight_bounds(1);
  } else {
    // No explicit bounds configured: clamp only at the float range itself.
    _min_bound = -std::numeric_limits<float>::max();
    _max_bound = std::numeric_limits<float>::max();
  }
}
// AdaGrad step with a single shared accumulator: every dimension is scaled
// by sqrt(initial_g2sum / (initial_g2sum + g2sum)); the squared (scale-
// divided) gradients of this step are averaged into g2sum afterwards.
void SparseAdaGradSGDRule::UpdateValueWork(float* w,
                                           float* sgd,
                                           const float* grad,
                                           float scale) {
  // Shared scalar squared-gradient accumulator inside the rule's state.
  float& g2sum = sgd[G2SumIndex()];
  double add_g2sum = 0;
  for (size_t i = 0; i < _embedding_dim; i++) {
    double scaled_grad = grad[i] / scale;
    w[i] -= learning_rate_ * scaled_grad *
            sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
    BoundValue(w[i]);
    add_g2sum += scaled_grad * scaled_grad;
  }
  // Applied once after the loop so all dimensions of one update see the
  // same accumulator value.
  g2sum += add_g2sum / _embedding_dim;
}
// Initialize weights (zero or uniform in [-_initial_range, _initial_range]),
// clamp each one, and reset the shared g2sum accumulator.
void SparseAdaGradSGDRule::InitValueWork(float* value,
                                         float* sgd,
                                         bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
      value[i] = 0.0;
    } else {
      value[i] = (local_uniform_real_distribution<double>()(
                      local_random_engine()) *
                      2 -
                  1) *
                 _initial_range;
    }
    BoundValue(value[i]);
  }
  // The single shared accumulator starts at zero.
  sgd[G2SumIndex()] = 0;
}
// Reads the AdaGrad section of `param` for the per-dimension (standard)
// AdaGrad rule: learning rate, initial g2sum, init range and the optional
// [min, max] weight bounds.
void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
                                   size_t emb_dim) {
  _embedding_dim = emb_dim;
  const auto& cfg = param.adagrad();
  learning_rate_ = cfg.learning_rate();
  _initial_g2sum = cfg.initial_g2sum();
  _initial_range = cfg.initial_range();
  const int num_bounds = cfg.weight_bounds_size();
  if (num_bounds > 0) {
    CHECK(num_bounds >= 2) << "invalid repeated size for weight_bounds:"
                           << num_bounds;
    _min_bound = cfg.weight_bounds(0);
    _max_bound = cfg.weight_bounds(1);
  } else {
    // No explicit bounds configured: clamp only at the float range itself.
    _min_bound = -std::numeric_limits<float>::max();
    _max_bound = std::numeric_limits<float>::max();
  }
}
// Standard AdaGrad: one accumulator per dimension at sgd[G2SumIndex() + i].
// Unlike SparseAdaGradSGDRule, each dimension's accumulator is updated
// immediately inside the loop.
void StdAdaGradSGDRule::UpdateValueWork(float* w,
                                        float* sgd,
                                        const float* grad,
                                        float scale) {
  for (size_t i = 0; i < _embedding_dim; i++) {
    float& g2sum = sgd[G2SumIndex() + i];
    double scaled_grad = grad[i] / scale;
    w[i] -= learning_rate_ * scaled_grad *
            sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
    BoundValue(w[i]);
    g2sum += scaled_grad * scaled_grad;
  }
}
// Initialize weights (zero or uniform in [-_initial_range, _initial_range]),
// clamp each one, and zero the matching per-dimension g2sum accumulator.
void StdAdaGradSGDRule::InitValueWork(float* value,
                                      float* sgd,
                                      bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
      value[i] = 0.0;
    } else {
      value[i] = (local_uniform_real_distribution<double>()(
                      local_random_engine()) *
                      2 -
                  1) *
                 _initial_range;
    }
    BoundValue(value[i]);
    // Each dimension owns its accumulator; reset it alongside the weight.
    sgd[G2SumIndex() + i] = 0;
  }
}
// Reads the Adam section of `param`: learning rate, init range, beta1/beta2
// decay rates, the epsilon used in the update denominator, and the optional
// [min, max] weight bounds.
void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
                                   size_t emb_dim) {
  _embedding_dim = emb_dim;
  const auto& cfg = param.adam();
  learning_rate_ = cfg.learning_rate();
  _initial_range = cfg.initial_range();
  _beta1_decay_rate = cfg.beta1_decay_rate();
  _beta2_decay_rate = cfg.beta2_decay_rate();
  _ada_epsilon = cfg.ada_epsilon();
  const int num_bounds = cfg.weight_bounds_size();
  if (num_bounds > 0) {
    CHECK(num_bounds >= 2) << "invalid repeated size for weight_bounds:"
                           << num_bounds;
    _min_bound = cfg.weight_bounds(0);
    _max_bound = cfg.weight_bounds(1);
  } else {
    // No explicit bounds configured: clamp only at the float range itself.
    _min_bound = -std::numeric_limits<float>::max();
    _max_bound = std::numeric_limits<float>::max();
  }
}
// Adam step. Optimizer state layout inside `sgd`:
//   [ gsum[dim] | g2sum[dim] | beta1_pow | beta2_pow ]
// The learning rate is bias-corrected once per call using the stored beta
// powers, which are then decayed at the end.
// NOTE(review): `scale` is unused here — the gradient appears to be assumed
// pre-scaled; confirm against the AdaGrad rules, which divide by `scale`.
void SparseAdamSGDRule::UpdateValueWork(float* w,
                                        float* sgd,
                                        const float* grad,
                                        float scale) {
  float* gsum = sgd + GSumIndex();
  float* g2sum = sgd + G2SumIndex();
  float* beta1_pow = sgd + Beta1PowIndex();
  float* beta2_pow = sgd + Beta2PowIndex();
  const float* g = grad;
  float lr = learning_rate_;
  float beta1_pow_ = *beta1_pow;
  float beta2_pow_ = *beta2_pow;
  // lr not change in one update
  lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
  for (size_t i = 0; i < _embedding_dim; i++) {
    // First/second moment estimates, then the bias-corrected weight step.
    gsum[i] = _beta1_decay_rate * gsum[i] + (1 - _beta1_decay_rate) * g[i];
    g2sum[i] =
        _beta2_decay_rate * g2sum[i] + (1 - _beta2_decay_rate) * g[i] * g[i];
    w[i] = w[i] - lr * (gsum[i] / (sqrt(g2sum[i]) + _ada_epsilon));
    BoundValue(w[i]);
  }
  // update beta_pow_decay
  (*beta1_pow) *= _beta1_decay_rate;
  (*beta2_pow) *= _beta2_decay_rate;
}
// Initialize weights (zero or uniform in [-_initial_range, _initial_range])
// and reset the Adam state: both moment vectors to zero, and the beta power
// slots to the first-step decay values.
void SparseAdamSGDRule::InitValueWork(float* value, float* sgd,
                                      bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
      value[i] = 0.0;
      BoundValue(value[i]);
    } else {
      value[i] = (local_uniform_real_distribution<double>()(
                      local_random_engine()) *
                      2 -
                  1) *
                 _initial_range;
      BoundValue(value[i]);
    }
  }
  // init rule gsum and g2sum (the range [GSumIndex, Beta1PowIndex) spans
  // both moment vectors)
  for (size_t i = GSumIndex(); i < Beta1PowIndex(); i++) {
    sgd[i] = 0.0;
  }
  // init beta1_pow and beta2_pow
  *(sgd + Beta1PowIndex()) = _beta1_decay_rate;
  *(sgd + Beta2PowIndex()) = _beta2_decay_rate;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/sparse_sgd_rule.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <thread>
#include <vector>
#include "glog/logging.h" // for CHECK
#include "paddle/fluid/distributed/common/local_random.h" // for local_uniform_real_distribution
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace
paddle
{
namespace
distributed
{
// Base class for per-value SGD update rules used by sparse accessors.
// A rule owns the hyper-parameters loaded from SparseCommonSGDRuleParameter
// and knows how to (a) initialize an embedding slice plus its optimizer
// state and (b) apply one gradient update to it.
class SparseValueSGDRule {
 public:
  SparseValueSGDRule() {}
  virtual ~SparseValueSGDRule() {}
  // Load hyper-parameters; `emb_dim` is the embedding width handled by
  // this rule.
  virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
                          size_t emb_dim) {
    _embedding_dim = emb_dim;
    _name = param.name();
  }
  // Apply one gradient `push_value` to the weights `w`; `sgd` points at
  // this rule's optimizer state; `scale` divides the gradient where the
  // concrete rule supports it.
  virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
                               float scale) = 0;
  // Initialize weights and optimizer state; zero weights when `zero_init`.
  virtual void InitValueWork(float* value, float* sgd, bool zero_init) = 0;
  // Number of floats of optimizer state this rule keeps per value.
  virtual size_t Dim() = 0;
  const std::string& GetName() const { return _name; }
  // Convenience wrapper with zero_init defaulting to true.
  void InitValue(float* value, float* sgd, bool zero_init = true) {
    InitValueWork(value, sgd, zero_init);
  }
  // Convenience wrapper with scale defaulting to 1.
  void UpdateValue(float* w, float* sgd, const float* push_value,
                   float scale = 1) {
    UpdateValueWork(w, sgd, push_value, scale);
  }
  // Clamp `w` into [_min_bound, _max_bound]. The negated comparisons also
  // reset NaN to _min_bound, since NaN fails every comparison.
  template <class T>
  void BoundValue(T& w) {  // NOLINT
    if (!(w >= _min_bound)) {
      w = (T)_min_bound;
    } else if (!(w <= _max_bound)) {
      w = (T)_max_bound;
    }
  }
  float& MinBound() { return _min_bound; }
  float& MaxBound() { return _max_bound; }

 protected:
  float _min_bound;
  float _max_bound;
  float _initial_range;
  size_t _embedding_dim;

 private:
  std::string _name;
};
REGISTER_PSCORE_REGISTERER(SparseValueSGDRule);
// Plain SGD rule: w -= lr * g. Keeps no optimizer state (Dim() == 0).
class SparseNaiveSGDRule : public SparseValueSGDRule {
 public:
  virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
                          size_t emb_dim);
  virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
                               float scale);
  virtual void InitValueWork(float* value, float* sgd, bool zero_init);
  // No per-value optimizer state.
  virtual size_t Dim() { return 0; }

 private:
  float learning_rate_;
};
// AdaGrad rule with a single shared squared-gradient accumulator for the
// whole embedding (Dim() == 1, stored at G2SumIndex()).
class SparseAdaGradSGDRule : public SparseValueSGDRule {
 public:
  virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
                          size_t emb_dim);
  virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
                               float scale);
  virtual void InitValueWork(float* value, float* sgd, bool zero_init);
  // One shared accumulator.
  virtual size_t Dim() { return 1; }
  size_t G2SumIndex() { return 0; }

 private:
  float learning_rate_;
  float _initial_g2sum;
};
// Standard AdaGrad rule with one squared-gradient accumulator per embedding
// dimension (Dim() == _embedding_dim, starting at G2SumIndex()).
class StdAdaGradSGDRule : public SparseValueSGDRule {
 public:
  virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
                          size_t emb_dim);
  virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
                               float scale);
  virtual void InitValueWork(float* value, float* sgd, bool zero_init);
  // One accumulator per dimension.
  virtual size_t Dim() { return _embedding_dim; }
  size_t G2SumIndex() { return 0; }

 private:
  float learning_rate_;
  float _initial_g2sum;
};
// Adam rule. State layout (Dim() == 2 * dim + 2):
//   [ gsum[dim] | g2sum[dim] | beta1_pow | beta2_pow ]
class SparseAdamSGDRule : public SparseValueSGDRule {
 public:
  virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
                          size_t emb_dim);
  virtual void UpdateValueWork(float* w, float* sgd, const float* push_value,
                               float scale);
  virtual void InitValueWork(float* value, float* sgd, bool zero_init);
  // Two moment vectors plus the two beta power scalars.
  virtual size_t Dim() { return _embedding_dim * 2 + 2; }
  // Offsets into the state block, each derived from the previous one.
  size_t GSumIndex() { return 0; }
  size_t G2SumIndex() { return GSumIndex() + _embedding_dim; }
  size_t Beta1PowIndex() { return G2SumIndex() + _embedding_dim; }
  size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }

 protected:
  float learning_rate_;
  float _beta1_decay_rate;
  float _beta2_decay_rate;
  float _ada_epsilon;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/common/local_random.h"
#include "paddle/fluid/distributed/common/topk_calculator.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/utils/string/string_helper.h"
// Flags defined elsewhere that this table consults.
DECLARE_bool(pserver_print_missed_key_num_every_push);
DECLARE_bool(pserver_create_value_when_push);
DECLARE_bool(pserver_enable_create_feasign_randomly);
// NOTE(review): defined but not referenced in this file — presumably used
// by other pserver components; confirm before removing.
DEFINE_bool(pserver_open_strict_check, false, "pserver_open_strict_check");
// Filesystem path backing the sparse table's RocksDB store.
DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file");
// Batch size used when loading entries from SSD.
DEFINE_int32(pserver_load_batch_size, 5000, "load batch size for ssd");
namespace
paddle
{
namespace
distributed
{
// Initialize the in-memory tier via MemorySparseTable, then attach the
// process-wide RocksDB handler (keyed by FLAGS_rocksdb_path) that backs the
// SSD tier, sized by the number of real local shards.
int32_t SSDSparseTable::Initialize() {
  MemorySparseTable::Initialize();
  _db = paddle::distributed::RocksDBHandler::GetInstance();
  _db->initialize(FLAGS_rocksdb_path, _real_local_shard_num);
  return 0;
}
// Per-shard setup is already handled in Initialize(); nothing to do here.
int32_t SSDSparseTable::InitializeShard() { return 0; }
// Dispatch a pull request: pointer-mode requests return raw value pointers
// via PullSparsePtr; otherwise values are copied out via PullSparse.
int32_t SSDSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    char** pull_values = context.pull_context.ptr_values;
    const uint64_t* keys = context.pull_context.keys;
    return PullSparsePtr(pull_values, keys, context.num);
  } else {
    float* pull_values = context.pull_context.values;
    const PullSparseValue& pull_value = context.pull_context.pull_value;
    return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_);
  }
}
// Dispatch a push request to the matching PushSparse overload (per-key
// pointer values vs one flat float array).
int32_t SSDSparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    return PushSparse(
        context.push_context.keys, context.push_context.ptr_values,
        context.num);
  } else {
    const uint64_t* keys = context.push_context.keys;
    const float* values = context.push_context.values;
    size_t num = context.num;
    return PushSparse(keys, values, num);
  }
}
// Copy-out pull. Keys are bucketed per local shard and processed on that
// shard's task-pool thread. Resolution order per key: in-memory shard hit;
// else RocksDB hit (value is promoted back into memory and deleted from the
// DB); else a brand-new value is created (or a zero buffer is used when
// FLAGS_pserver_create_value_when_push defers creation to push time). The
// mf tail of the buffer is zero-filled before the accessor's Select copies
// the result into `pull_values`.
int32_t SSDSparseTable::PullSparse(float* pull_values,
                                   const uint64_t* keys,
                                   size_t num) {
  CostTimer timer("pserver_downpour_sparse_select_all");
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t select_value_size =
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
  {
    // Fetch from the table, or create on miss.
    std::vector<std::future<int>> tasks(_real_local_shard_num);
    std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
        _real_local_shard_num);
    for (size_t i = 0; i < num; ++i) {
      int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
      task_keys[shard_id].push_back({keys[i], i});
    }
    std::atomic<uint32_t> missed_keys{0};
    for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
      tasks[shard_id] =
          _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
              [this, shard_id, &task_keys, value_size, mf_value_size,
               select_value_size, pull_values, keys, &missed_keys]() -> int {
                auto& keys = task_keys[shard_id];
                auto& local_shard = _local_shards[shard_id];
                float data_buffer[value_size];  // NOLINT
                float* data_buffer_ptr = data_buffer;
                for (size_t i = 0; i < keys.size(); ++i) {
                  uint64_t key = keys[i].first;
                  auto itr = local_shard.find(key);
                  size_t data_size = value_size - mf_value_size;
                  if (itr == local_shard.end()) {
                    // Not in memory: try RocksDB.
                    std::string tmp_string("");
                    if (_db->get(shard_id, reinterpret_cast<char*>(&key),
                                 sizeof(uint64_t), tmp_string) > 0) {
                      // Missing everywhere: count it and create (or defer).
                      ++missed_keys;
                      if (FLAGS_pserver_create_value_when_push) {
                        memset(data_buffer, 0, sizeof(float) * data_size);
                      } else {
                        auto& feature_value = local_shard[key];
                        feature_value.resize(data_size);
                        float* data_ptr =
                            const_cast<float*>(feature_value.data());
                        _value_accesor->Create(&data_buffer_ptr, 1);
                        memcpy(data_ptr, data_buffer_ptr,
                               data_size * sizeof(float));
                      }
                    } else {
                      data_size = tmp_string.size() / sizeof(float);
                      memcpy(data_buffer_ptr,
                             paddle::string::str_to_float(tmp_string),
                             data_size * sizeof(float));
                      // Promote from rocksdb to mem.
                      auto& feature_value = local_shard[key];
                      feature_value.resize(data_size);
                      memcpy(const_cast<float*>(feature_value.data()),
                             data_buffer_ptr, data_size * sizeof(float));
                      _db->del_data(shard_id, reinterpret_cast<char*>(&key),
                                    sizeof(uint64_t));
                    }
                  } else {
                    data_size = itr.value().size();
                    memcpy(data_buffer_ptr, itr.value().data(),
                           data_size * sizeof(float));
                  }
                  // Zero-fill the (possibly absent) mf tail.
                  for (size_t mf_idx = data_size; mf_idx < value_size;
                       ++mf_idx) {
                    data_buffer[mf_idx] = 0.0;
                  }
                  int pull_data_idx = keys[i].second;
                  float* select_data =
                      pull_values + pull_data_idx * select_value_size;
                  _value_accesor->Select(
                      &select_data, (const float**)&data_buffer_ptr, 1);
                }
                return 0;
              });
    }
    for (int i = 0; i < _real_local_shard_num; ++i) {
      tasks[i].wait();
    }
    if (FLAGS_pserver_print_missed_key_num_every_push) {
      LOG(WARNING) << "total pull keys:" << num
                   << " missed_keys:" << missed_keys.load();
    }
  }
  return 0;
}
// Pointer-mode pull. Same shard bucketing and memory/RocksDB resolution as
// PullSparse, but instead of copying values out it stores a pointer to the
// (now in-memory) FixedFeatureValue into `pull_values`. Missing keys are
// always created here, so every output slot ends up non-null.
int32_t SSDSparseTable::PullSparsePtr(char** pull_values,
                                      const uint64_t* keys,
                                      size_t num) {
  CostTimer timer("pserver_ssd_sparse_select_all");
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  {
    // Fetch from the table, or create on miss.
    std::vector<std::future<int>> tasks(_real_local_shard_num);
    std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
        _real_local_shard_num);
    for (size_t i = 0; i < num; ++i) {
      int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
      task_keys[shard_id].push_back({keys[i], i});
    }
    std::atomic<uint32_t> missed_keys{0};
    for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
      tasks[shard_id] =
          _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
              [this, shard_id, &task_keys, value_size, mf_value_size,
               pull_values, &missed_keys]() -> int {
                auto& keys = task_keys[shard_id];
                auto& local_shard = _local_shards[shard_id];
                float data_buffer[value_size];  // NOLINT
                float* data_buffer_ptr = data_buffer;
                for (size_t i = 0; i < keys.size(); ++i) {
                  uint64_t key = keys[i].first;
                  auto itr = local_shard.find(key);
                  size_t data_size = value_size - mf_value_size;
                  FixedFeatureValue* ret = NULL;
                  if (itr == local_shard.end()) {
                    // Not in memory: try RocksDB.
                    std::string tmp_string("");
                    if (_db->get(shard_id, reinterpret_cast<char*>(&key),
                                 sizeof(uint64_t), tmp_string) > 0) {
                      // Missing everywhere: create a fresh value in memory.
                      ++missed_keys;
                      auto& feature_value = local_shard[key];
                      feature_value.resize(data_size);
                      float* data_ptr =
                          const_cast<float*>(feature_value.data());
                      _value_accesor->Create(&data_buffer_ptr, 1);
                      memcpy(data_ptr, data_buffer_ptr,
                             data_size * sizeof(float));
                      ret = &feature_value;
                    } else {
                      data_size = tmp_string.size() / sizeof(float);
                      memcpy(data_buffer_ptr,
                             paddle::string::str_to_float(tmp_string),
                             data_size * sizeof(float));
                      // Promote from rocksdb to mem.
                      auto& feature_value = local_shard[key];
                      feature_value.resize(data_size);
                      memcpy(const_cast<float*>(feature_value.data()),
                             data_buffer_ptr, data_size * sizeof(float));
                      _db->del_data(shard_id, reinterpret_cast<char*>(&key),
                                    sizeof(uint64_t));
                      ret = &feature_value;
                    }
                  } else {
                    ret = itr.value_ptr();
                  }
                  int pull_data_idx = keys[i].second;
                  pull_values[pull_data_idx] = reinterpret_cast<char*>(ret);
                }
                return 0;
              });
    }
    for (int i = 0; i < _real_local_shard_num; ++i) {
      tasks[i].wait();
    }
    if (FLAGS_pserver_print_missed_key_num_every_push) {
      LOG(WARNING) << "total pull keys:" << num
                   << " missed_keys:" << missed_keys.load();
    }
  }
  return 0;
}
// Push with one flat gradient array (`values` holds `num` records of
// update_value_col floats each). Keys are bucketed per shard; each shard
// thread creates missing values (subject to the random-create flag), then
// applies the accessor Update either in place (value already extended to
// full size) or through a stack buffer, extending the mf part on demand.
int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
                                   const float* values,
                                   size_t num) {
  CostTimer timer("pserver_downpour_sparse_update_all");
  // Build the value / push_value data pointers.
  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
  {
    std::vector<std::future<int>> tasks(_real_local_shard_num);
    std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
        _real_local_shard_num);
    for (size_t i = 0; i < num; ++i) {
      int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
      task_keys[shard_id].push_back({keys[i], i});
    }
    for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
      tasks[shard_id] =
          _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
              [this, shard_id, value_col, mf_value_col, update_value_col,
               values, &task_keys]() -> int {
                auto& keys = task_keys[shard_id];
                auto& local_shard = _local_shards[shard_id];
                float data_buffer[value_col];  // NOLINT
                float* data_buffer_ptr = data_buffer;
                for (size_t i = 0; i < keys.size(); ++i) {
                  uint64_t key = keys[i].first;
                  uint64_t push_data_idx = keys[i].second;
                  const float* update_data =
                      values + push_data_idx * update_value_col;
                  auto itr = local_shard.find(key);
                  if (itr == local_shard.end()) {
                    // Optionally drop the gradient instead of creating.
                    if (FLAGS_pserver_enable_create_feasign_randomly &&
                        !_value_accesor->CreateValue(1, update_data)) {
                      continue;
                    }
                    auto value_size = value_col - mf_value_col;
                    auto& feature_value = local_shard[key];
                    feature_value.resize(value_size);
                    _value_accesor->Create(&data_buffer_ptr, 1);
                    memcpy(const_cast<float*>(feature_value.data()),
                           data_buffer_ptr, value_size * sizeof(float));
                    itr = local_shard.find(key);
                  }
                  auto& feature_value = itr.value();
                  float* value_data =
                      const_cast<float*>(feature_value.data());
                  size_t value_size = feature_value.size();
                  if (value_size == value_col) {
                    // Already extended to full size: update in place.
                    _value_accesor->Update(&value_data, &update_data, 1);
                  } else {
                    // Copy into the buffer, update there, then copy back;
                    // unneeded mf parts are dropped on the way back.
                    memcpy(data_buffer_ptr, value_data,
                           value_size * sizeof(float));
                    _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
                    if (_value_accesor->NeedExtendMF(data_buffer)) {
                      feature_value.resize(value_col);
                      value_data = const_cast<float*>(feature_value.data());
                      _value_accesor->Create(&value_data, 1);
                    }
                    memcpy(value_data, data_buffer_ptr,
                           value_size * sizeof(float));
                  }
                }
                return 0;
              });
    }
    for (int i = 0; i < _real_local_shard_num; ++i) {
      tasks[i].wait();
    }
  }
  /*
  //update && value 的转置
  thread_local Eigen::MatrixXf update_matrix;
  float* transposed_update_data[update_value_col];
  make_matrix_with_eigen(num, update_value_col, update_matrix,
  transposed_update_data);
  copy_array_to_eigen(values, update_matrix);
  thread_local Eigen::MatrixXf value_matrix;
  float* transposed_value_data[value_col];
  make_matrix_with_eigen(num, value_col, value_matrix, transposed_value_data);
  copy_matrix_to_eigen((const float**)(value_ptrs->data()), value_matrix);
  //批量update
  {
  CostTimer accessor_timer("pslib_downpour_sparse_update_accessor");
  _value_accesor->update(transposed_value_data, (const
  float**)transposed_update_data, num);
  }
  copy_eigen_to_matrix(value_matrix, value_ptrs->data());
  */
  return 0;
}
// Push with per-record gradient pointers (`values[i]` points at record i).
// Otherwise identical in structure to the flat-array overload above:
// bucket keys per shard, create missing values, then update in place or
// via a stack buffer with on-demand mf extension.
int32_t SSDSparseTable::PushSparse(const uint64_t* keys,
                                   const float** values,
                                   size_t num) {
  CostTimer timer("pserver_downpour_sparse_update_all");
  // Build the value / push_value data pointers.
  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
  {
    std::vector<std::future<int>> tasks(_real_local_shard_num);
    std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
        _real_local_shard_num);
    for (size_t i = 0; i < num; ++i) {
      int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
      task_keys[shard_id].push_back({keys[i], i});
    }
    for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
      tasks[shard_id] =
          _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
              [this, shard_id, value_col, mf_value_col, update_value_col,
               values, &task_keys]() -> int {
                auto& keys = task_keys[shard_id];
                auto& local_shard = _local_shards[shard_id];
                float data_buffer[value_col];  // NOLINT
                float* data_buffer_ptr = data_buffer;
                for (size_t i = 0; i < keys.size(); ++i) {
                  uint64_t key = keys[i].first;
                  uint64_t push_data_idx = keys[i].second;
                  const float* update_data = values[push_data_idx];
                  auto itr = local_shard.find(key);
                  if (itr == local_shard.end()) {
                    // Optionally drop the gradient instead of creating.
                    if (FLAGS_pserver_enable_create_feasign_randomly &&
                        !_value_accesor->CreateValue(1, update_data)) {
                      continue;
                    }
                    auto value_size = value_col - mf_value_col;
                    auto& feature_value = local_shard[key];
                    feature_value.resize(value_size);
                    _value_accesor->Create(&data_buffer_ptr, 1);
                    memcpy(const_cast<float*>(feature_value.data()),
                           data_buffer_ptr, value_size * sizeof(float));
                    itr = local_shard.find(key);
                  }
                  auto& feature_value = itr.value();
                  float* value_data =
                      const_cast<float*>(feature_value.data());
                  size_t value_size = feature_value.size();
                  if (value_size == value_col) {
                    // Already extended to full size: update in place.
                    _value_accesor->Update(&value_data, &update_data, 1);
                  } else {
                    // Copy into the buffer, update there, then copy back;
                    // unneeded mf parts are dropped on the way back.
                    memcpy(data_buffer_ptr, value_data,
                           value_size * sizeof(float));
                    _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
                    if (_value_accesor->NeedExtendMF(data_buffer)) {
                      feature_value.resize(value_col);
                      value_data = const_cast<float*>(feature_value.data());
                      _value_accesor->Create(&value_data, 1);
                    }
                    memcpy(value_data, data_buffer_ptr,
                           value_size * sizeof(float));
                  }
                }
                return 0;
              });
    }
    for (int i = 0; i < _real_local_shard_num; ++i) {
      tasks[i].wait();
    }
  }
  return 0;
}
// Evict stale values from both tiers. Shards are processed in parallel via
// OpenMP (capped at 20 threads). In-memory entries the accessor marks
// shrinkable are erased; RocksDB entries are deleted when shrinkable,
// otherwise written back.
// NOTE(review): non-shrinkable SSD entries are re-put unconditionally —
// presumably because the accessor's Shrink may mutate the value in place;
// confirm before simplifying.
int32_t SSDSparseTable::Shrink(const std::string& param) {
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    uint64_t mem_count = 0;
    uint64_t ssd_count = 0;
    LOG(INFO) << "SSDSparseTable begin shrink shard:" << i;
    auto& shard = _local_shards[i];
    for (auto it = shard.begin(); it != shard.end();) {
      if (_value_accesor->Shrink(it.value().data())) {
        it = shard.erase(it);
        mem_count++;
      } else {
        ++it;
      }
    }
    auto* it = _db->get_iterator(i);
    for (it->SeekToFirst(); it->Valid(); it->Next()) {
      if (_value_accesor->Shrink(
              paddle::string::str_to_float(it->value().data()))) {
        _db->del_data(i, it->key().data(), it->key().size());
        ssd_count++;
      } else {
        _db->put(i, it->key().data(), it->key().size(), it->value().data(),
                 it->value().size());
      }
    }
    delete it;
    LOG(INFO) << "SSDSparseTable shrink success. shard:" << i
              << " delete MEM[" << mem_count << "] SSD[" << ssd_count << "]";
    // _db->flush(i);
  }
  return 0;
}
// Demote cold values: every in-memory entry the accessor flags with
// SaveSSD() is written to RocksDB and erased from the memory shard, then
// each shard's DB is flushed.
int32_t SSDSparseTable::UpdateTable() {
  // TODO implement with multi-thread
  int count = 0;
  for (int i = 0; i < _real_local_shard_num; ++i) {
    auto& shard = _local_shards[i];
    // Move entries from mem to ssd.
    for (auto it = shard.begin(); it != shard.end();) {
      if (_value_accesor->SaveSSD(it.value().data())) {
        _db->put(i, (char*)&it.key(), sizeof(uint64_t),
                 (char*)it.value().data(), it.value().size() * sizeof(float));
        count++;
        it = shard.erase(it);
      } else {
        ++it;
      }
    }
    _db->flush(i);
  }
  LOG(INFO) << "Table>> update count: " << count;
  return 0;
}
// Sums the element counts of all in-memory local shards. Entries that have
// been demoted to RocksDB are not counted (see the TODO below).
int64_t SSDSparseTable::LocalSize() {
  int64_t total = 0;
  for (int shard = 0; shard < _real_local_shard_num; ++shard) {
    total += _local_shards[shard].size();
  }
  // TODO rocksdb size
  return total;
}
// Saves the table to `path`, one part file per local shard, retrying a
// shard's file from scratch on any write error.
// `param` is the save mode as a decimal string (batch_model:0, xbox:1;
// modes 2/3/4 are accessor-defined filters — see _value_accesor->Save).
// Side effects: refreshes _cache_tk_size, records the top-k "show" value
// into _local_show_threshold, and for save_param==3 migrates mem entries
// to rocksdb via UpdateTable(). Always returns 0 (see note at the end).
int32_t SSDSparseTable::Save(const std::string& path,
                             const std::string& param) {
  if (_real_local_shard_num == 0) {
    // Nothing local to save; -1 marks the threshold as "unset".
    _local_show_threshold = -1;
    return 0;
  }
  int save_param = atoi(param.c_str());
  // batch_model:0 xbox:1
  // if (save_param == 5) {
  //   return save_patch(path, save_param);
  // }
  // LOG(INFO) << "table cache rate is: " << _config.sparse_table_cache_rate();
  LOG(INFO) << "table cache rate is: " << _config.sparse_table_cache_rate();
  LOG(INFO) << "enable_sparse_table_cache: "
            << _config.enable_sparse_table_cache();
  LOG(INFO) << "LocalSize: " << LocalSize();
  if (_config.enable_sparse_table_cache()) {
    LOG(INFO) << "Enable sparse table cache, top n:" << _cache_tk_size;
  }
  _cache_tk_size = LocalSize() * _config.sparse_table_cache_rate();
  // Tracks the top-k "show" values across shards; its threshold is
  // published to _local_show_threshold at the end.
  TopkCalculator tk(_real_local_shard_num, _cache_tk_size);
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
  std::string table_path = TableDir(path);
  // Remove any stale part files from a previous save of this pserver.
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  // std::atomic<uint32_t> feasign_size;
  std::atomic<uint32_t> feasign_size_all{0};
  // feasign_size = 0;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    // Compression only applies to full saves (modes 0 and 3).
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d.gz",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    } else {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    }
    channel_config.converter = _value_accesor->Converter(save_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(save_param).deconverter;
    int err_no = 0;
    int retry_num = 0;
    bool is_write_failed = false;
    int feasign_size = 0;
    auto& shard = _local_shards[i];
    // Retry loop: any write failure removes the partial file and rewrites
    // the whole shard (retry_num is only logged, never bounded here).
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
      // Pass 1: in-memory entries.
      for (auto it = shard.begin(); it != shard.end(); ++it) {
        // Feed the top-k calculator when caching is on for xbox-style
        // saves; filter 4 is the accessor's cache-candidate predicate.
        if (_config.enable_sparse_table_cache() &&
            (save_param == 1 || save_param == 2) &&
            _value_accesor->Save(it.value().data(), 4)) {
          // tk.push(i, it.value().data()[2]);
          tk.push(i, _value_accesor->GetField(it.value().data(), "show"));
        }
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
              it.value().data(), it.value().size());
          if (0 != write_channel->write_line(paddle::string::format_string(
                       "%lu %s", it.key(), format_value.c_str()))) {
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR) << "SSDSparseTable save failed, retry it! path:"
                       << channel_config.path << ", retry_num=" << retry_num;
            break;
          }
          ++feasign_size;
        }
      }
      // The channel may report a deferred error even when every
      // write_line returned 0.
      if (err_no == -1 && !is_write_failed) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR) << "SSDSparseTable save failed after write, retry it! "
                   << "path:" << channel_config.path
                   << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
        continue;
      }
      // delta and cache and revert is all in mem, base in rocksdb
      // Pass 2: rocksdb-resident entries (skipped for delta save, mode 1).
      if (save_param != 1) {
        auto* it = _db->get_iterator(i);
        for (it->SeekToFirst(); it->Valid(); it->Next()) {
          bool need_save = _value_accesor->Save(
              paddle::string::str_to_float(it->value().data()), save_param);
          _value_accesor->UpdateStatAfterSave(
              paddle::string::str_to_float(it->value().data()), save_param);
          if (need_save) {
            std::string format_value = _value_accesor->ParseToString(
                paddle::string::str_to_float(it->value().data()),
                it->value().size() / sizeof(float));
            if (0 != write_channel->write_line(paddle::string::format_string(
                         "%lu %s",
                         *((uint64_t*)const_cast<char*>(it->key().data())),
                         format_value.c_str()))) {
              ++retry_num;
              is_write_failed = true;
              LOG(ERROR) << "SSDSparseTable save failed, retry it! path:"
                         << channel_config.path
                         << ", retry_num=" << retry_num;
              break;
            }
            // Mode 3 rewrites the (possibly stat-updated) value back.
            if (save_param == 3) {
              _db->put(i,
                       it->key().data(),
                       it->key().size(),
                       it->value().data(),
                       it->value().size());
            }
            ++feasign_size;
          }
        }
        delete it;
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR) << "SSDSparseTable save failed after write, retry it! "
                   << "path:" << channel_config.path
                   << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
    // Post-save stat refresh for the in-memory entries of this shard.
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
    }
  }
  if (save_param == 3) {
    UpdateTable();
    _cache_tk_size = LocalSize() * _config.sparse_table_cache_rate();
    LOG(INFO) << "SSDSparseTable update success.";
  }
  LOG(INFO) << "SSDSparseTable save success, path:"
            << paddle::string::format_string("%s/%03d/part-%03d-",
                                             path.c_str(),
                                             _config.table_id(),
                                             _shard_idx)
            << " from " << file_start_idx << " to "
            << file_start_idx + _real_local_shard_num - 1;
  // return feasign_size_all;
  _local_show_threshold = tk.top();
  LOG(INFO) << "local cache threshold: " << _local_show_threshold;
  // int32 may overflow need to change return value
  return 0;
}
// Shuffles cache-eligible in-memory entries across `shuffle_node_num`
// pservers. Each local shard serializes its eligible "key value" pairs
// into a channel; pairs are then routed by a random pserver id — remote
// ones are sent via `send_msg_func` (msg_type 101), local ones collected
// into `shuffled_channel`.
// `path` and `table_ptrs` are unused here (interface parity with other
// table types). Returns 0.
// NOTE: std::random_shuffle was removed in C++17; std::shuffle with the
// process-local engine is used instead (same effect: randomized send
// order to avoid synchronized fan-out).
int64_t SSDSparseTable::CacheShuffle(
    const std::string& path,
    const std::string& param,
    double cache_threshold,
    std::function<std::future<int32_t>(
        int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
    paddle::framework::Channel<std::pair<uint64_t, std::string>>&
        shuffled_channel,
    const std::vector<Table*>& table_ptrs) {
  LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold
            << " param:" << param;
  int save_param = atoi(param.c_str());  // batch_model:0 xbox:1
  if (!_config.enable_sparse_table_cache() || cache_threshold < 0) {
    // Deliberately best-effort: warn but keep going.
    LOG(WARNING)
        << "cache shuffle failed not enable table cache or cache threshold < 0 "
        << _config.enable_sparse_table_cache() << " or " << cache_threshold;
    // return -1;
  }
  int shuffle_node_num = _config.sparse_table_cache_file_num();
  LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num;
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  std::vector<
      paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>>
      writers(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, std::string>>> datas(
      _real_local_shard_num);
  int feasign_size = 0;
  std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
      tmp_channels;
  for (int i = 0; i < _real_local_shard_num; ++i) {
    tmp_channels.push_back(
        paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
  }
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
        writers[i];
    writer.Reset(tmp_channels[i].get());
    auto& shard = _local_shards[i];
    // Serialize every cache-eligible entry of this shard into its channel.
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      if (_value_accesor->SaveCache(
              it.value().data(), save_param, cache_threshold)) {
        std::string format_value = _value_accesor->ParseToString(
            it.value().data(), it.value().size());
        std::pair<uint64_t, std::string> pkv(it.key(), format_value.c_str());
        writer << pkv;
        ++feasign_size;
      }
    }
    writer.Flush();
    writer.channel()->Close();
  }
  LOG(INFO) << "SSDSparseTable cache KV save success to Channel feasigh size: "
            << feasign_size
            << " and start sparse cache data shuffle real local shard num: "
            << _real_local_shard_num;
  std::vector<std::pair<uint64_t, std::string>> local_datas;
  for (int idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
        writers[idx_shard];
    auto channel = writer.channel();
    std::vector<std::pair<uint64_t, std::string>>& data = datas[idx_shard];
    std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num);
    while (channel->Read(data)) {
      // Route each pair to a random pserver; keep local-bound pairs.
      for (auto& t : data) {
        auto pserver_id =
            paddle::distributed::local_random_engine()() % shuffle_node_num;
        if (pserver_id != _shard_idx) {
          ars[pserver_id] << t;
        } else {
          local_datas.emplace_back(std::move(t));
        }
      }
      std::vector<std::future<int32_t>> total_status;
      std::vector<uint32_t> send_data_size(shuffle_node_num, 0);
      std::vector<int> send_index(shuffle_node_num);
      for (int i = 0; i < shuffle_node_num; ++i) {
        send_index[i] = i;
      }
      // Randomize send order so pservers don't all hit the same peer at
      // once. std::shuffle replaces the C++17-removed std::random_shuffle.
      std::shuffle(send_index.begin(),
                   send_index.end(),
                   paddle::distributed::local_random_engine());
      for (int index = 0; index < shuffle_node_num; ++index) {
        size_t i = send_index[index];
        if (i == _shard_idx) {
          continue;
        }
        if (ars[i].Length() == 0) {
          continue;
        }
        std::string msg(ars[i].Buffer(), ars[i].Length());
        auto ret = send_msg_func(101, i, msg);
        total_status.push_back(std::move(ret));
        send_data_size[i] += ars[i].Length();
      }
      for (auto& t : total_status) {
        t.wait();
      }
      // Reset archives and batch buffer for the next channel read.
      ars.clear();
      ars = std::vector<paddle::framework::BinaryArchive>(shuffle_node_num);
      data = std::vector<std::pair<uint64_t, std::string>>();
    }
  }
  shuffled_channel->Write(std::move(local_datas));
  LOG(INFO) << "cache shuffle finished";
  return 0;
}
int32_t
SSDSparseTable
::
SaveCache
(
const
std
::
string
&
path
,
const
std
::
string
&
param
,
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>&
shuffled_channel
)
{
if
(
_shard_idx
>=
_config
.
sparse_table_cache_file_num
())
{
return
0
;
}
int
save_param
=
atoi
(
param
.
c_str
());
// batch_model:0 xbox:1
std
::
string
table_path
=
paddle
::
string
::
format_string
(
"%s/%03d_cache/"
,
path
.
c_str
(),
_config
.
table_id
());
_afs_client
.
remove
(
paddle
::
string
::
format_string
(
"%s/part-%03d"
,
table_path
.
c_str
(),
_shard_idx
));
uint32_t
feasign_size
=
0
;
FsChannelConfig
channel_config
;
// not compress cache model
channel_config
.
path
=
paddle
::
string
::
format_string
(
"%s/part-%03d"
,
table_path
.
c_str
(),
_shard_idx
);
channel_config
.
converter
=
_value_accesor
->
Converter
(
save_param
).
converter
;
channel_config
.
deconverter
=
_value_accesor
->
Converter
(
save_param
).
deconverter
;
auto
write_channel
=
_afs_client
.
open_w
(
channel_config
,
1024
*
1024
*
40
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
data
;
bool
is_write_failed
=
false
;
shuffled_channel
->
Close
();
while
(
shuffled_channel
->
Read
(
data
))
{
for
(
auto
&
t
:
data
)
{
++
feasign_size
;
if
(
0
!=
write_channel
->
write_line
(
paddle
::
string
::
format_string
(
"%lu %s"
,
t
.
first
,
t
.
second
.
c_str
())))
{
LOG
(
ERROR
)
<<
"Cache Table save failed, "
"path:"
<<
channel_config
.
path
<<
", retry it!"
;
is_write_failed
=
true
;
break
;
}
}
data
=
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
();
}
if
(
is_write_failed
)
{
_afs_client
.
remove
(
channel_config
.
path
);
}
write_channel
->
close
();
LOG
(
INFO
)
<<
"SSDSparseTable cache save success, feasign: "
<<
feasign_size
<<
", path: "
<<
channel_config
.
path
;
shuffled_channel
->
Open
();
return
feasign_size
;
}
// Loads the table from `path`. Delegates entirely to the in-memory base
// implementation; the SSD-aware per-file overload below handles routing
// entries between memory and rocksdb.
int32_t SSDSparseTable::Load(const std::string& path,
                             const std::string& param) {
  return MemorySparseTable::Load(path, param);
}
//加载path目录下数据[start_idx, end_idx)
int32_t
SSDSparseTable
::
Load
(
size_t
start_idx
,
size_t
end_idx
,
const
std
::
vector
<
std
::
string
>&
file_list
,
const
std
::
string
&
param
)
{
if
(
start_idx
>=
file_list
.
size
())
{
return
0
;
}
int
load_param
=
atoi
(
param
.
c_str
());
size_t
feature_value_size
=
_value_accesor
->
GetAccessorInfo
().
size
/
sizeof
(
float
);
size_t
mf_value_size
=
_value_accesor
->
GetAccessorInfo
().
mf_size
/
sizeof
(
float
);
end_idx
=
static_cast
<
int
>
(
end_idx
)
<
_sparse_table_shard_num
?
end_idx
:
_sparse_table_shard_num
;
int
thread_num
=
(
end_idx
-
start_idx
)
<
20
?
(
end_idx
-
start_idx
)
:
20
;
omp_set_num_threads
(
thread_num
);
#pragma omp parallel for schedule(dynamic)
for
(
size_t
i
=
start_idx
;
i
<
end_idx
;
++
i
)
{
FsChannelConfig
channel_config
;
channel_config
.
path
=
file_list
[
i
];
channel_config
.
converter
=
_value_accesor
->
Converter
(
load_param
).
converter
;
channel_config
.
deconverter
=
_value_accesor
->
Converter
(
load_param
).
deconverter
;
int
retry_num
=
0
;
int
err_no
=
0
;
bool
is_read_failed
=
false
;
std
::
vector
<
std
::
pair
<
char
*
,
int
>>
ssd_keys
;
std
::
vector
<
std
::
pair
<
char
*
,
int
>>
ssd_values
;
std
::
vector
<
uint64_t
>
tmp_key
;
ssd_keys
.
reserve
(
FLAGS_pserver_load_batch_size
);
ssd_values
.
reserve
(
FLAGS_pserver_load_batch_size
);
tmp_key
.
reserve
(
FLAGS_pserver_load_batch_size
);
do
{
ssd_keys
.
clear
();
ssd_values
.
clear
();
tmp_key
.
clear
();
err_no
=
0
;
is_read_failed
=
false
;
std
::
string
line_data
;
auto
read_channel
=
_afs_client
.
open_r
(
channel_config
,
0
,
&
err_no
);
char
*
end
=
NULL
;
int
local_shard_id
=
i
%
_avg_local_shard_num
;
auto
&
shard
=
_local_shards
[
local_shard_id
];
float
data_buffer
[
FLAGS_pserver_load_batch_size
*
feature_value_size
];
float
*
data_buffer_ptr
=
data_buffer
;
uint64_t
mem_count
=
0
;
uint64_t
ssd_count
=
0
;
uint64_t
mem_mf_count
=
0
;
uint64_t
ssd_mf_count
=
0
;
try
{
while
(
read_channel
->
read_line
(
line_data
)
==
0
&&
line_data
.
size
()
>
1
)
{
uint64_t
key
=
std
::
strtoul
(
line_data
.
data
(),
&
end
,
10
);
if
(
FLAGS_pserver_open_strict_check
)
{
if
(
key
%
_sparse_table_shard_num
!=
i
)
{
LOG
(
WARNING
)
<<
"SSDSparseTable key:"
<<
key
<<
" not match shard,"
<<
" file_idx:"
<<
i
<<
" shard num:"
<<
_sparse_table_shard_num
<<
" file:"
<<
channel_config
.
path
;
continue
;
}
}
size_t
value_size
=
_value_accesor
->
ParseFromString
(
++
end
,
data_buffer_ptr
);
// ssd or mem
if
(
_value_accesor
->
SaveSSD
(
data_buffer_ptr
))
{
tmp_key
.
emplace_back
(
key
);
ssd_keys
.
emplace_back
(
std
::
make_pair
((
char
*
)
&
tmp_key
.
back
(),
sizeof
(
uint64_t
)));
ssd_values
.
emplace_back
(
std
::
make_pair
((
char
*
)
data_buffer_ptr
,
value_size
*
sizeof
(
float
)));
data_buffer_ptr
+=
feature_value_size
;
if
(
static_cast
<
int
>
(
ssd_keys
.
size
())
==
FLAGS_pserver_load_batch_size
)
{
_db
->
put_batch
(
local_shard_id
,
ssd_keys
,
ssd_values
,
ssd_keys
.
size
());
ssd_keys
.
clear
();
ssd_values
.
clear
();
tmp_key
.
clear
();
data_buffer_ptr
=
data_buffer
;
}
ssd_count
++
;
if
(
value_size
>
feature_value_size
-
mf_value_size
)
{
ssd_mf_count
++
;
}
}
else
{
auto
&
value
=
shard
[
key
];
value
.
resize
(
value_size
);
_value_accesor
->
ParseFromString
(
end
,
value
.
data
());
mem_count
++
;
if
(
value_size
>
feature_value_size
-
mf_value_size
)
{
mem_mf_count
++
;
}
}
}
// last batch
if
(
ssd_keys
.
size
()
>
0
)
{
_db
->
put_batch
(
local_shard_id
,
ssd_keys
,
ssd_values
,
ssd_keys
.
size
());
}
read_channel
->
close
();
if
(
err_no
==
-
1
)
{
++
retry_num
;
is_read_failed
=
true
;
LOG
(
ERROR
)
<<
"SSDSparseTable load failed after read, retry it! path:"
<<
channel_config
.
path
<<
" , retry_num="
<<
retry_num
;
continue
;
}
_db
->
flush
(
local_shard_id
);
LOG
(
INFO
)
<<
"Table>> load done. ALL["
<<
mem_count
+
ssd_count
<<
"] MEM["
<<
mem_count
<<
"] MEM_MF["
<<
mem_mf_count
<<
"] SSD["
<<
ssd_count
<<
"] SSD_MF["
<<
ssd_mf_count
<<
"]."
;
}
catch
(...)
{
++
retry_num
;
is_read_failed
=
true
;
LOG
(
ERROR
)
<<
"SSDSparseTable load failed after read, retry it! path:"
<<
channel_config
.
path
<<
" , retry_num="
<<
retry_num
;
}
}
while
(
is_read_failed
);
}
LOG
(
INFO
)
<<
"load num:"
<<
LocalSize
();
LOG
(
INFO
)
<<
"SSDSparseTable load success, path from "
<<
file_list
[
start_idx
]
<<
" to "
<<
file_list
[
end_idx
-
1
];
_cache_tk_size
=
LocalSize
()
*
_config
.
sparse_table_cache_rate
();
return
0
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/ssd_sparse_table.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
namespace
paddle
{
namespace
distributed
{
// Sparse table whose hot entries live in in-memory shards and cold
// entries spill to rocksdb (_db). Extends MemorySparseTable with
// mem<->ssd migration, cache shuffle/save, and SSD-aware load/save.
class SSDSparseTable : public MemorySparseTable {
 public:
  typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
  SSDSparseTable() {}
  virtual ~SSDSparseTable() {}

  int32_t Initialize() override;
  int32_t InitializeShard() override;

  // Migrates qualifying entries from the in-memory shards to rocksdb.
  int32_t UpdateTable();

  int32_t Pull(TableContext& context) override;
  int32_t Push(TableContext& context) override;

  int32_t PullSparse(float* pull_values, const uint64_t* keys, size_t num);
  int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num);
  int32_t PushSparse(const uint64_t* keys, const float* values, size_t num);
  int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);

  int32_t Flush() override { return 0; }
  virtual int32_t Shrink(const std::string& param) override;
  // Clears the in-memory shards only; rocksdb contents are untouched.
  virtual void Clear() override {
    for (int i = 0; i < _real_local_shard_num; ++i) {
      _local_shards[i].clear();
    }
  }

  virtual int32_t Save(const std::string& path,
                       const std::string& param) override;
  virtual int32_t SaveCache(
      const std::string& path,
      const std::string& param,
      paddle::framework::Channel<std::pair<uint64_t, std::string>>&
          shuffled_channel) override;
  virtual double GetCacheThreshold() override { return _local_show_threshold; }
  virtual int64_t CacheShuffle(
      const std::string& path,
      const std::string& param,
      double cache_threshold,
      std::function<std::future<int32_t>(
          int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
      paddle::framework::Channel<std::pair<uint64_t, std::string>>&
          shuffled_channel,
      const std::vector<Table*>& table_ptrs) override;
  // Loads the data under `path`.
  virtual int32_t Load(const std::string& path,
                       const std::string& param) override;
  // Loads the data under `path` for shard range [start_idx, end_idx).
  virtual int32_t Load(size_t start_idx,
                       size_t end_idx,
                       const std::vector<std::string>& file_list,
                       const std::string& param);
  // Number of entries in the in-memory shards (rocksdb not counted).
  int64_t LocalSize();

 private:
  // Rocksdb handle for SSD-resident entries. Default-initialized so the
  // pointer is never read uninitialized before Initialize() runs.
  RocksDBHandler* _db = nullptr;
  // Top-k cache capacity, refreshed on save/load from the cache rate.
  int64_t _cache_tk_size = 0;
  // "show" threshold published by Save(); -1 means unset.
  double _local_show_threshold{0.0};
};
}
// namespace distributed
}
// namespace paddle
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment