Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
d2d32668
Commit
d2d32668
authored
Apr 26, 2023
by
yuguo960516yuguo
Browse files
2.3.0-dtk-22.04.2
parent
ad08b8ce
Pipeline
#226
failed with stages
in 0 seconds
Changes
268
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6936 additions
and
0 deletions
+6936
-0
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
...istributed/fleet_executor/test/source_interceptor_test.cc
+85
-0
paddle/fluid/distributed/index_dataset/CMakeLists.txt
paddle/fluid/distributed/index_dataset/CMakeLists.txt
+19
-0
paddle/fluid/distributed/index_dataset/index_dataset.proto
paddle/fluid/distributed/index_dataset/index_dataset.proto
+33
-0
paddle/fluid/distributed/index_dataset/index_sampler.cc
paddle/fluid/distributed/index_dataset/index_sampler.cc
+138
-0
paddle/fluid/distributed/index_dataset/index_sampler.h
paddle/fluid/distributed/index_dataset/index_sampler.h
+139
-0
paddle/fluid/distributed/index_dataset/index_wrapper.cc
paddle/fluid/distributed/index_dataset/index_wrapper.cc
+202
-0
paddle/fluid/distributed/index_dataset/index_wrapper.h
paddle/fluid/distributed/index_dataset/index_wrapper.h
+125
-0
paddle/fluid/distributed/ps/CMakeLists.txt
paddle/fluid/distributed/ps/CMakeLists.txt
+4
-0
paddle/fluid/distributed/ps/README.md
paddle/fluid/distributed/ps/README.md
+39
-0
paddle/fluid/distributed/ps/service/CMakeLists.txt
paddle/fluid/distributed/ps/service/CMakeLists.txt
+136
-0
paddle/fluid/distributed/ps/service/README.md
paddle/fluid/distributed/ps/service/README.md
+8
-0
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+1919
-0
paddle/fluid/distributed/ps/service/brpc_ps_client.h
paddle/fluid/distributed/ps/service/brpc_ps_client.h
+365
-0
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
+871
-0
paddle/fluid/distributed/ps/service/brpc_ps_server.h
paddle/fluid/distributed/ps/service/brpc_ps_server.h
+230
-0
paddle/fluid/distributed/ps/service/brpc_utils.cc
paddle/fluid/distributed/ps/service/brpc_utils.cc
+355
-0
paddle/fluid/distributed/ps/service/brpc_utils.h
paddle/fluid/distributed/ps/service/brpc_utils.h
+98
-0
paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt
.../fluid/distributed/ps/service/communicator/CMakeLists.txt
+15
-0
paddle/fluid/distributed/ps/service/communicator/communicator.cc
...fluid/distributed/ps/service/communicator/communicator.cc
+1494
-0
paddle/fluid/distributed/ps/service/communicator/communicator.h
.../fluid/distributed/ps/service/communicator/communicator.h
+661
-0
No files found.
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
class
FakeInterceptor
:
public
Interceptor
{
public:
FakeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
step_
=
0
;
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
std
::
cout
<<
"FakeInterceptor run in scope "
<<
msg
.
scope_idx
()
<<
std
::
endl
;
InterceptorMessage
reply
;
reply
.
set_message_type
(
DATA_IS_USELESS
);
Send
(
SOURCE_ID
,
reply
);
step_
++
;
if
(
step_
==
node_
->
max_run_times
())
{
carrier_
->
WakeUp
();
}
}
}
private:
int64_t
step_
;
};
// End-to-end check of the Source interceptor: a "Source" interceptor feeds a
// FakeInterceptor that acknowledges every message; the test returns once the
// fake interceptor has seen node_a->max_run_times() pieces of data and calls
// WakeUp() on the carrier.
TEST(SourceInterceptor, Source) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  // Both interceptor ids (SOURCE_ID and 0) are mapped to rank 0.
  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source =
      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id

  // Wire source -> node_a. NOTE(review): the second argument is presumably a
  // buffer size — confirm against TaskNode's API.
  source->AddDownstreamTask(0, 1);
  node_a->AddUpstreamTask(SOURCE_ID, 1);
  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));

  // start: kick the source interceptor, then block until FakeInterceptor
  // wakes the carrier.
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  carrier->Wait();
  carrier->Release();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/index_dataset/CMakeLists.txt
0 → 100644
View file @
d2d32668
# Protobuf schema shared by the tree-index wrapper and sampler.
proto_library(index_dataset_proto SRCS index_dataset.proto)

cc_library(
  index_wrapper
  SRCS index_wrapper.cc
  DEPS index_dataset_proto fs)

# The two cc_library(index_sampler ...) branches differed only in the mkldnn
# dependency; build the dependency list in a variable instead of duplicating
# the whole target definition.
set(INDEX_SAMPLER_DEPS xxhash index_wrapper eigen3)
if(WITH_MKLDNN)
  list(APPEND INDEX_SAMPLER_DEPS mkldnn)
endif()
cc_library(
  index_sampler
  SRCS index_sampler.cc
  DEPS ${INDEX_SAMPLER_DEPS})

if(WITH_PYTHON)
  py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
endif()
paddle/fluid/distributed/index_dataset/index_dataset.proto
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// proto2 is required: the messages below rely on required/optional presence.
syntax = "proto2";
package paddle.distributed;

// One node of the tree index.
message IndexNode {
  required uint64 id = 1;          // node id (0 is reserved for fake nodes)
  required bool is_leaf = 2;       // true when the node is a leaf
  required float probability = 3;  // probability weight assigned by the producer
  optional string item_name = 4;   // optional human-readable item name
}

// Global shape of the tree.
message TreeMeta {
  required int32 height = 1;  // number of layers
  required int32 branch = 2;  // fan-out of every internal node
}

// Generic key/value record used when (de)serializing a tree file:
// key is either a node code or the special ".tree_meta" entry, value is the
// corresponding serialized IndexNode or TreeMeta (see TreeIndex::Load).
message KVItem {
  required bytes key = 1;
  required bytes value = 2;
}
paddle/fluid/distributed/index_dataset/index_sampler.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
distributed
{
// Layer-wise negative sampling: for every target id, walk its ancestor path
// from the deepest layer up to start_sample_layer_ and, per layer, emit one
// positive row (the path node) plus layer_counts_[j] negative rows drawn from
// that layer (redrawing until the draw differs from the positive node).
//
// Each output row has user_feature_num + 2 columns:
//   [0, user_feature_num)   user features (replaced by the user's ancestors
//                           at the current layer when with_hierarchy is set)
//   [user_feature_num]      sampled node id
//   [user_feature_num + 1]  1 for the positive row, 0 for negatives
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
    const std::vector<std::vector<uint64_t>>& user_inputs,
    const std::vector<uint64_t>& target_ids,
    bool with_hierarchy) {
  // NOTE(review): assumes user_inputs is non-empty, all rows have
  // user_inputs[0].size() features, and target_ids.size() matches
  // user_inputs.size() — confirm against callers.
  auto input_num = target_ids.size();
  auto user_feature_num = user_inputs[0].size();
  // layer_counts_sum_ is the per-target row count (positives + negatives
  // over all sampled layers), so this pre-sizes outputs exactly.
  std::vector<std::vector<uint64_t>> outputs(
      input_num * layer_counts_sum_,
      std::vector<uint64_t>(user_feature_num + 2));

  auto max_layer = tree_->Height();
  size_t idx = 0;  // next free row in outputs
  for (size_t i = 0; i < input_num; i++) {
    // Ancestor chain of the target, ordered deepest layer first.
    auto travel_codes =
        tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
    auto travel_path = tree_->GetNodes(travel_codes);
    for (size_t j = 0; j < travel_path.size(); j++) {
      // user
      // Fill the user-feature columns of the positive row and of the
      // layer_counts_[j] negative rows for this layer (hence `<=`).
      if (j > 0 && with_hierarchy) {
        // Replace raw user ids by their ancestors at the current layer.
        auto ancestor_codes =
            tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1);
        auto hierarchical_user = tree_->GetNodes(ancestor_codes);
        for (int idx_offset = 0; idx_offset <= layer_counts_[j];
             idx_offset++) {
          for (size_t k = 0; k < user_feature_num; k++) {
            outputs[idx + idx_offset][k] = hierarchical_user[k].id();
          }
        }
      } else {
        for (int idx_offset = 0; idx_offset <= layer_counts_[j];
             idx_offset++) {
          for (size_t k = 0; k < user_feature_num; k++) {
            outputs[idx + idx_offset][k] = user_inputs[i][k];
          }
        }
      }
      // sampler ++
      // Positive sample: the path node itself.
      outputs[idx][user_feature_num] = travel_path[j].id();
      // NOTE(review): 1.0 is silently converted to uint64_t 1 here.
      outputs[idx][user_feature_num + 1] = 1.0;
      idx += 1;
      // Negative samples: uniform draws from this layer, rejecting the
      // positive node.
      for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
        int sample_res = 0;
        do {
          sample_res = sampler_vec_[j]->Sample();
        } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
        outputs[idx + idx_offset][user_feature_num] =
            layer_ids_[j][sample_res].id();
        outputs[idx + idx_offset][user_feature_num + 1] = 0;
      }
      idx += layer_counts_[j];
    }
  }
  return outputs;
}
// Dataset-side layer-wise sampling: for each Record whose uint64 feasigns
// contain `sample_slot`, emit one positive Record per travel-path layer (that
// slot's value replaced by the layer's ancestor id) plus layer_counts_[j]
// negative Records per layer, with the feasign at sample_feasign_idx + 1
// (the label, per the original comment below) zeroed.
// Records that do not carry `sample_slot` contribute nothing to the output.
void LayerWiseSampler::sample_from_dataset(
    const uint16_t sample_slot,
    std::vector<paddle::framework::Record>* src_datas,
    std::vector<paddle::framework::Record>* sample_results) {
  sample_results->clear();
  for (auto& data : *src_datas) {
    VLOG(1) << "src data size = " << src_datas->size();
    VLOG(1) << "float data size = " << data.float_feasigns_.size();
    // data.Print();
    uint64_t start_idx = sample_results->size();
    VLOG(1) << "before sample, sample_results.size = " << start_idx;
    // Locate the first feasign belonging to sample_slot.
    // NOTE(review): sample_feasign_idx stays at uint64_t(-1) when the slot is
    // absent; it is only read when sample_sign is true, so the sentinel never
    // escapes.
    uint64_t sample_feasign_idx = -1;
    bool sample_sign = false;
    for (unsigned int i = 0; i < data.uint64_feasigns_.size(); i++) {
      VLOG(1) << "slot" << i << " = " << data.uint64_feasigns_[i].slot();
      if (data.uint64_feasigns_[i].slot() == sample_slot) {
        sample_sign = true;
        sample_feasign_idx = i;
      }
      if (sample_sign) break;
    }
    VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx;
    if (sample_sign) {
      auto target_id =
          data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_;
      // Ancestor chain of the target, deepest layer first.
      auto travel_codes =
          tree_->GetTravelCodes(target_id, start_sample_layer_);
      auto travel_path = tree_->GetNodes(travel_codes);
      for (unsigned int j = 0; j < travel_path.size(); j++) {
        // Positive sample: copy the record and overwrite the target slot
        // with this layer's ancestor id.
        paddle::framework::Record instance(data);
        instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
            travel_path[j].id();
        sample_results->push_back(instance);
        // Negative samples for this layer.
        for (int idx_offset = 0; idx_offset < layer_counts_[j];
             idx_offset++) {
          int sample_res = 0;
          // Redraw until the negative differs from the positive node.
          do {
            sample_res = sampler_vec_[j]->Sample();
          } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
          // NOTE(review): this inner `instance` shadows the positive one
          // above — intentional, each negative is a fresh copy of `data`.
          paddle::framework::Record instance(data);
          instance.uint64_feasigns_[sample_feasign_idx]
              .sign()
              .uint64_feasign_ = layer_ids_[j][sample_res].id();
          VLOG(1) << "layer id :" << layer_ids_[j][sample_res].id();
          // sample_feasign_idx + 1 == label's id
          instance.uint64_feasigns_[sample_feasign_idx + 1]
              .sign()
              .uint64_feasign_ = 0;
          sample_results->push_back(instance);
        }
        VLOG(1) << "layer end!!!!!!!!!!!!!!!!!!";
      }
    }
  }
  VLOG(1) << "after sample, sample_results.size = " << sample_results->size();
  return;
}
// Convert a vector of doubles to uint64_t by truncation toward zero.
// Values must be non-negative and representable in 64 bits; other inputs hit
// the usual undefined behavior of float-to-integer conversion.
// Fixes: reserve the output once (the original reallocated while growing)
// and use static_cast instead of a C-style functional cast.
std::vector<uint64_t> float2int(std::vector<double> tmp) {
  std::vector<uint64_t> tmp_int;
  tmp_int.reserve(tmp.size());
  for (const double v : tmp) {
    tmp_int.push_back(static_cast<uint64_t>(v));
  }
  return tmp_int;
}
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_sampler.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
// Abstract interface for sampling strategies over a tree index.
// Concrete strategies (e.g. LayerWiseSampler) are created via Init<T>(name).
class IndexSampler {
 public:
  virtual ~IndexSampler() {}
  IndexSampler() {}

  // Factory: constructs a T(name) and returns it as the interface type.
  template <typename T>
  static std::shared_ptr<IndexSampler> Init(const std::string& name) {
    std::shared_ptr<IndexSampler> instance = nullptr;
    instance.reset(new T(name));
    return instance;
  }

  // Optional configuration hooks; the defaults do nothing so each concrete
  // sampler only overrides the one matching its strategy.
  virtual void init_layerwise_conf(
      const std::vector<uint16_t>& layer_sample_counts,
      uint16_t start_sample_layer = 1,
      uint16_t seed = 0) {}
  virtual void init_beamsearch_conf(const int64_t k) {}

  // Draw samples for each (user_inputs row, input_targets element) pair; the
  // result layout is defined by the concrete sampler.
  virtual std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& input_targets,
      bool with_hierarchy = false) = 0;

  // Expand the records of *src_datas into *sample_results, driven by the
  // slot that carries the target id.
  virtual void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) = 0;
};
// IndexSampler backed by a TreeIndex: negatives are drawn per tree layer by
// a uniform sampler, rejecting the positive node (see index_sampler.cc).
class LayerWiseSampler : public IndexSampler {
 public:
  virtual ~LayerWiseSampler() {}
  // Binds to an already-registered tree; aborts (PADDLE_ENFORCE inside
  // get_tree_index) if `name` was never inserted into IndexWrapper.
  explicit LayerWiseSampler(const std::string& name) {
    tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
  }

  // Precompute, for every sampled layer, the number of negatives to draw and
  // a uniform sampler over that layer's nodes.
  // layer_sample_counts[k] is the negative count for layer
  // (start_sample_layer + k); layers beyond the list default to 1.
  void init_layerwise_conf(const std::vector<uint16_t>& layer_sample_counts,
                           uint16_t start_sample_layer,
                           uint16_t seed) override {
    seed_ = seed;
    start_sample_layer_ = start_sample_layer;

    PADDLE_ENFORCE_GT(
        start_sample_layer_,
        0,
        paddle::platform::errors::InvalidArgument(
            "start sampler layer = [%d], it should greater than 0.",
            start_sample_layer_));
    PADDLE_ENFORCE_LT(start_sample_layer_,
                      tree_->Height(),
                      paddle::platform::errors::InvalidArgument(
                          "start sampler layer = [%d], it should less than "
                          "max_layer, which is [%d].",
                          start_sample_layer_,
                          tree_->Height()));

    size_t i = 0;
    layer_counts_sum_ = 0;
    layer_counts_.clear();
    int cur_layer = start_sample_layer_;
    while (cur_layer < tree_->Height()) {
      // Default to one negative per layer once the caller's list runs out.
      int layer_sample_num = 1;
      if (i < layer_sample_counts.size()) {
        layer_sample_num = layer_sample_counts[i];
      }
      // +1 accounts for the positive example emitted alongside negatives.
      layer_counts_sum_ += layer_sample_num + 1;
      layer_counts_.push_back(layer_sample_num);
      VLOG(3) << "[INFO] level " << cur_layer
              << " sample_layer_counts.push_back: " << layer_sample_num;
      cur_layer += 1;
      i += 1;
    }
    // sample() walks travel paths deepest-layer-first, so flip the counts
    // to match that order.
    reverse(layer_counts_.begin(), layer_counts_.end());
    VLOG(3) << "sample counts sum: " << layer_counts_sum_;

    auto max_layer = tree_->Height();
    sampler_vec_.clear();
    layer_ids_.clear();

    // Build layer_ids_ / sampler_vec_ from the deepest layer upward so their
    // indices line up with the reversed layer_counts_.
    auto layer_index = max_layer - 1;
    size_t idx = 0;
    while (layer_index >= start_sample_layer_) {
      auto layer_codes = tree_->GetLayerCodes(layer_index);
      layer_ids_.push_back(tree_->GetNodes(layer_codes));
      // NOTE(review): presumably UniformSampler samples the inclusive range
      // [0, arg], hence the size() - 1 — confirm against its definition.
      auto sampler_temp =
          std::make_shared<paddle::operators::math::UniformSampler>(
              layer_ids_[idx].size() - 1, seed_);
      sampler_vec_.push_back(sampler_temp);
      layer_index--;
      idx++;
    }
  }

  // Implemented in index_sampler.cc.
  std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& target_ids,
      bool with_hierarchy) override;

  void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) override;

 private:
  std::vector<int> layer_counts_;  // negatives per sampled layer, deepest first
  int64_t layer_counts_sum_{0};    // rows emitted per target (pos + neg)
  std::shared_ptr<TreeIndex> tree_{nullptr};
  int seed_{0};
  int start_sample_layer_{1};
  std::vector<std::shared_ptr<paddle::operators::math::Sampler>> sampler_vec_;
  std::vector<std::vector<IndexNode>> layer_ids_;  // nodes per sampled layer
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_wrapper.cc
0 → 100644
View file @
d2d32668
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
distributed
{
// Definition of the lazily-created IndexWrapper singleton (see
// IndexWrapper::GetInstance in index_wrapper.h); starts out null.
std::shared_ptr<IndexWrapper> IndexWrapper::s_instance_(nullptr);
// Deserialize a tree index from `filename`.
//
// File layout: repeated [int32 length][`length` bytes of serialized KVItem].
// The item with key ".tree_meta" carries TreeMeta; for every other item the
// key is the node code (decimal string) and the value a serialized IndexNode.
// Populates data_, id_codes_map_ (leaves only), meta_, max_id_, max_code_
// and total_nodes_num_. Returns 0 on success; failures abort through
// PADDLE_ENFORCE.
int TreeIndex::Load(const std::string filename) {
  int err_no;
  auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
  PADDLE_ENFORCE_NE(
      fp,
      nullptr,
      platform::errors::InvalidArgument(
          "Open file %s failed. Please check whether the file exists.",
          filename));

  int num = 0;
  max_id_ = 0;
  // fake_node_ is what GetNodes() returns for codes missing from data_.
  fake_node_.set_id(0);
  fake_node_.set_is_leaf(false);
  fake_node_.set_probability(0.0);
  max_code_ = 0;
  size_t ret = fread(&num, sizeof(num), 1, fp.get());
  while (ret == 1 && num > 0) {
    std::string content(num, '\0');
    size_t read_num =
        fread(const_cast<char*>(content.data()), 1, num, fp.get());
    PADDLE_ENFORCE_EQ(
        read_num,
        static_cast<size_t>(num),
        platform::errors::InvalidArgument(
            "Read from file: %s failed. Valid Format is "
            "an integer representing the length of the following string, "
            "and the string itself.We got an iteger[% d], "
            "but the following string's length is [%d].",
            filename,
            num,
            read_num));

    KVItem item;
    PADDLE_ENFORCE_EQ(
        item.ParseFromString(content),
        true,
        platform::errors::InvalidArgument("Parse from file: %s failed. It's "
                                          "content can't be parsed by KVItem.",
                                          filename));

    if (item.key() == ".tree_meta") {
      meta_.ParseFromString(item.value());
    } else {
      auto code = std::stoull(item.key());
      IndexNode node;
      node.ParseFromString(item.value());
      // PADDLE_ENFORCE_NE(node.id(), 0,
      //                   platform::errors::InvalidArgument(
      //                       "Node'id should not be equel to zero."));
      if (node.is_leaf()) {
        // Only leaves get an id -> code entry.
        id_codes_map_[node.id()] = code;
      }
      data_[code] = node;
      if (node.id() > max_id_) {
        max_id_ = node.id();
      }
      if (code > max_code_) {
        max_code_ = code;
      }
    }
    // Next record's length prefix (ret != 1 or num <= 0 ends the loop).
    ret = fread(&num, sizeof(num), 1, fp.get());
  }
  total_nodes_num_ = data_.size();
  // One past the largest real code; GetAncestorCodes uses it as the
  // "id not in tree" marker.
  max_code_ += 1;
  return 0;
}
std
::
vector
<
IndexNode
>
TreeIndex
::
GetNodes
(
const
std
::
vector
<
uint64_t
>&
codes
)
{
std
::
vector
<
IndexNode
>
nodes
;
nodes
.
reserve
(
codes
.
size
());
for
(
size_t
i
=
0
;
i
<
codes
.
size
();
i
++
)
{
if
(
CheckIsValid
(
codes
[
i
]))
{
nodes
.
push_back
(
data_
.
at
(
codes
[
i
]));
}
else
{
nodes
.
push_back
(
fake_node_
);
}
}
return
nodes
;
}
// Return the codes of all nodes actually present on tree layer `level`.
// In a complete `branch`-ary tree, layer `level` holds branch^level slots
// starting at code branch^level - 1.
std::vector<uint64_t> TreeIndex::GetLayerCodes(int level) {
  const uint64_t slot_count =
      static_cast<uint64_t>(std::pow(meta_.branch(), level));
  const uint64_t first_code = slot_count - 1;

  std::vector<uint64_t> codes;
  codes.reserve(slot_count);
  for (uint64_t offset = 0; offset < slot_count; ++offset) {
    const uint64_t code = first_code + offset;
    if (CheckIsValid(code)) {
      codes.push_back(code);
    }
  }
  return codes;
}
// For each leaf id, return the code of its ancestor on layer `level`
// (computed by repeatedly stepping to the parent code (c - 1) / branch).
// Ids that are not in the tree map to max_code_, the "invalid" marker that
// GetNodes() resolves to fake_node_.
//
// Fixes: single map lookup (find + it->second instead of find + at), and
// cur_level declared inside the loop — it was previously an uninitialized
// function-scope variable shared across iterations.
std::vector<uint64_t> TreeIndex::GetAncestorCodes(
    const std::vector<uint64_t>& ids, int level) {
  std::vector<uint64_t> res;
  res.reserve(ids.size());
  for (const auto id : ids) {
    auto it = id_codes_map_.find(id);
    if (it == id_codes_map_.end()) {
      res.push_back(max_code_);
    } else {
      auto code = it->second;
      int cur_level = meta_.height() - 1;
      // Climb until we reach the requested layer (a negative `level`
      // disables climbing entirely, matching the original condition).
      while (level >= 0 && cur_level > level) {
        code = (code - 1) / meta_.branch();
        cur_level--;
      }
      res.push_back(code);
    }
  }
  return res;
}
// Collect the codes of all descendants of `ancestor` that lie on tree layer
// `level`, via breadth-first expansion over codes present in data_.
//
// Fix: removed the unused local `res` the original declared and never used.
//
// NOTE(review): the loop assumes `ancestor` has stored descendants all the
// way down to `level`; if a frontier ever comes up empty, parent[p_idx]
// reads one past the end — confirm callers only pass valid pairs.
std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
                                                  int level) {
  // Codes on layer `level` occupy [branch^level - 1, branch^(level+1) - 1).
  auto level_code_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
  auto code_min = level_code_num - 1;
  auto code_max = meta_.branch() * level_code_num - 1;

  // `parent` doubles as the BFS queue and visited list; p_idx is the head.
  std::vector<uint64_t> parent;
  parent.push_back(ancestor);
  size_t p_idx = 0;
  while (true) {
    size_t p_size = parent.size();
    for (; p_idx < p_size; p_idx++) {
      for (int i = 0; i < meta_.branch(); i++) {
        // Child codes of c are c * branch + 1 ... c * branch + branch.
        auto code = parent[p_idx] * meta_.branch() + i + 1;
        if (data_.find(code) != data_.end()) parent.push_back(code);
      }
    }
    // Stop once the next frontier lies on the requested layer.
    if ((code_min <= parent[p_idx]) && (parent[p_idx] < code_max)) {
      break;
    }
  }
  // Everything from the current head onward is the requested layer.
  return std::vector<uint64_t>(parent.begin() + p_idx, parent.end());
}
// Return the chain of codes from leaf `id` up to (and including) layer
// `start_level`, ordered deepest layer first. Aborts when `id` is unknown.
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
  PADDLE_ENFORCE_NE(id_codes_map_.find(id),
                    id_codes_map_.end(),
                    paddle::platform::errors::InvalidArgument(
                        "id = %d doesn't exist in Tree.", id));
  std::vector<uint64_t> path;
  auto code = id_codes_map_.at(id);
  for (int level = meta_.height() - 1; level >= start_level; --level) {
    path.push_back(code);
    code = (code - 1) / meta_.branch();  // step to the parent code
  }
  return path;
}
std
::
vector
<
IndexNode
>
TreeIndex
::
GetAllLeafs
()
{
std
::
vector
<
IndexNode
>
res
;
res
.
reserve
(
id_codes_map_
.
size
());
for
(
auto
&
ite
:
id_codes_map_
)
{
auto
code
=
ite
.
second
;
res
.
push_back
(
data_
.
at
(
code
));
}
return
res
;
}
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_wrapper.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
class
Index
{
public:
Index
()
{}
~
Index
()
{}
};
// In-memory tree index: a complete `branch`-ary tree addressed by code,
// where the node at code c has children c*branch + 1 ... c*branch + branch
// (see GetChildrenCodes in index_wrapper.cc). Loaded from file via Load().
class TreeIndex : public Index {
 public:
  TreeIndex() {}
  ~TreeIndex() {}

  // Number of layers in the tree.
  int Height() { return meta_.height(); }
  // Fan-out of every internal node.
  int Branch() { return meta_.branch(); }
  uint64_t TotalNodeNums() { return total_nodes_num_; }
  // One past the largest node id; usable as an embedding-table size.
  uint64_t EmbSize() { return max_id_ + 1; }
  // Deserialize the tree from `path` (format documented at the definition).
  int Load(const std::string path);

  // True when `code` refers to a node actually stored in the tree.
  // Fix: takes uint64_t — codes are 64-bit everywhere else, and the previous
  // `int` parameter silently truncated large codes.
  inline bool CheckIsValid(uint64_t code) {
    return data_.find(code) != data_.end();
  }

  std::vector<IndexNode> GetNodes(const std::vector<uint64_t>& codes);
  std::vector<uint64_t> GetLayerCodes(int level);
  std::vector<uint64_t> GetAncestorCodes(const std::vector<uint64_t>& ids,
                                         int level);
  std::vector<uint64_t> GetChildrenCodes(uint64_t ancestor, int level);
  std::vector<uint64_t> GetTravelCodes(uint64_t id, int start_level);
  std::vector<IndexNode> GetAllLeafs();

  // NOTE(review): members are public and accessed directly elsewhere; kept
  // public for compatibility.
  std::unordered_map<uint64_t, IndexNode> data_;         // code -> node
  std::unordered_map<uint64_t, uint64_t> id_codes_map_;  // leaf id -> code
  uint64_t total_nodes_num_;
  TreeMeta meta_;
  uint64_t max_id_;
  uint64_t max_code_;    // one past the largest stored code after Load()
  IndexNode fake_node_;  // placeholder returned for invalid codes
};
// Shared-ownership handle to a loaded tree, as stored in IndexWrapper.
using TreePtr = std::shared_ptr<TreeIndex>;
class
IndexWrapper
{
public:
virtual
~
IndexWrapper
()
{}
IndexWrapper
()
{}
void
clear_tree
()
{
tree_map
.
clear
();
}
TreePtr
get_tree_index
(
const
std
::
string
name
)
{
PADDLE_ENFORCE_NE
(
tree_map
.
find
(
name
),
tree_map
.
end
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"tree [%s] doesn't exist. Please insert it firstly "
"by API[
\'
insert_tree_index
\'
]."
,
name
));
return
tree_map
[
name
];
}
void
insert_tree_index
(
const
std
::
string
name
,
const
std
::
string
tree_path
)
{
if
(
tree_map
.
find
(
name
)
!=
tree_map
.
end
())
{
VLOG
(
0
)
<<
"Tree "
<<
name
<<
" has already existed."
;
return
;
}
TreePtr
tree
=
std
::
make_shared
<
TreeIndex
>
();
int
ret
=
tree
->
Load
(
tree_path
);
PADDLE_ENFORCE_EQ
(
ret
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Load tree[%s] from path[%s] failed. Please "
"check whether the file exists."
,
name
,
tree_path
));
tree_map
.
insert
(
std
::
pair
<
std
::
string
,
TreePtr
>
{
name
,
tree
});
}
static
std
::
shared_ptr
<
IndexWrapper
>
GetInstancePtr
()
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
distributed
::
IndexWrapper
());
}
return
s_instance_
;
}
static
IndexWrapper
*
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
distributed
::
IndexWrapper
());
}
return
s_instance_
.
get
();
}
private:
static
std
::
shared_ptr
<
IndexWrapper
>
s_instance_
;
std
::
unordered_map
<
std
::
string
,
TreePtr
>
tree_map
;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/CMakeLists.txt
0 → 100644
View file @
d2d32668
# Publish the RPC dependency set (brpc + helpers) as a global property so
# sibling directories can retrieve it with get_property(... RPC_DEPS).
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)

add_subdirectory(table)
add_subdirectory(service)
add_subdirectory(wrapper)
paddle/fluid/distributed/ps/README.md
0 → 100644
View file @
d2d32668
# 目录说明
Table: for param storage and update
-----MemorySparseTable: table for sparse param, used in cpu async mode
-----MemoryDenseTable: table for dense param, used in cpu async/geo mode
-----MemorySparseGeoTable: table for sparse param, used in cpu async mode
-----CommonGraphTable: table used for graph learning
-----BarrierTable: table for barrier function, used in cpu sync mode
-----TensorTable: table which runs a program, used for learning rate decay only
ValueAccessor: for pull param and push gradient
-----CtrCommonAccessor: pull/push value with show/click, float type
-----CtrDoubleAccessor: same as CtrCommonAccessor, other than show/click with double type
-----SparseAccessor: used for common embedding, pull value without show/click, push value with show/click
-----CommMergeAccessor: used for dense table only, for get param dim
PsService(proto): for server to handle request
-----PsBaseService
----------BrpcPsService: for cpu dnn training task
----------GraphBrpcService: for graph learning
-----HeterService: for dnn training task with heterogeneous computing resources
PSServer: recv request from trainer and handle it by service
-----BrpcPsServer: for cpu dnn training task
-----GraphBrpcServer: for graph learning
-----PsLocalServer: for GpuPS
HeterServer: for HeterPS
PSClient: pull param and push gradient for trainer
-----BrpcPsClient: for cpu dnn training task
----------GraphBrpcClient: for graph learning
-----PsLocalClient: for GpuPS
HeterClient: for HeterPS
PSCore: Wrapper for InitServer
GraphPyService: for graph learning
paddle/fluid/distributed/ps/service/CMakeLists.txt
0 → 100644
View file @
d2d32668
set(BRPC_SRCS ps_client.cc server.cc)
# NOTE(review): this call sets no properties and is effectively a no-op;
# kept for now — confirm it can be removed.
set_source_files_properties(${BRPC_SRCS})

# Common brpc dependency set (the original duplicated the whole list in an
# if/else and repeated gflags/glog twice); WITH_HETERPS only adds rocksdb.
set(BRPC_DEPS
    brpc
    ssl
    crypto
    protobuf
    gflags
    glog
    zlib
    leveldb
    snappy
    device_context)
if(WITH_HETERPS)
  list(APPEND BRPC_DEPS rocksdb)
endif()

brpc_library(
  sendrecv_rpc
  SRCS
  ${BRPC_SRCS}
  PROTO
  sendrecv.proto
  DEPS
  ${BRPC_DEPS})

#set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

# Every distributed source needs the same compile flags; apply them in one
# loop instead of thirteen copy-pasted set_source_files_properties calls.
set(DISTRIBUTE_FLAG_SRCS
    communicator/communicator.cc
    ps_service/service.cc
    brpc_ps_server.cc
    brpc_ps_client.cc
    ps_local_client.cc
    brpc_utils.cc
    heter_server.cc
    heter_client.cc
    client.cc
    ps_client.cc
    server.cc
    graph_brpc_server.cc
    graph_brpc_client.cc)
foreach(distribute_src ${DISTRIBUTE_FLAG_SRCS})
  set_source_files_properties(${distribute_src} PROPERTIES COMPILE_FLAGS
                              ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()

cc_library(
  brpc_utils
  SRCS brpc_utils.cc
  DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})

cc_library(
  downpour_server
  SRCS graph_brpc_server.cc brpc_ps_server.cc
  DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
cc_library(
  downpour_client
  SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
  DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})

cc_library(
  client
  SRCS ps_client.cc
  DEPS downpour_client boost ${RPC_DEPS})
cc_library(
  server
  SRCS server.cc
  DEPS downpour_server boost ${RPC_DEPS})

cc_library(
  communicator
  SRCS communicator/communicator.cc
  DEPS scope
       client
       boost
       table
       math_function
       selected_rows_functor
       ${RPC_DEPS})
cc_library(
  ps_service
  SRCS ps_service/service.cc
  DEPS communicator client server boost ${RPC_DEPS})

cc_library(
  heter_client
  SRCS heter_client.cc
  DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(
  heter_server
  SRCS heter_server.cc
  DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})

set_source_files_properties(
  ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
  graph_py_service
  SRCS ps_service/graph_py_service.cc
  DEPS ps_service)

#add_subdirectory(communicator)
paddle/fluid/distributed/ps/service/README.md
0 → 100644
View file @
d2d32668
# 目录说明 (Directory overview)

* PSServer
* PSClient
* PsService
* Communicator
* MessageBusFramework
* *.proto
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/archive.h"
// Upper bound of the port range used when starting the client-side brpc
// server (see BrpcPsClient::StartClientService).
static const int max_port = 65535;

// Runtime-tunable gflags for the brpc parameter-server client.
DEFINE_int32(pserver_push_dense_merge_limit, 12,
             "limit max push_dense local merge requests");

DEFINE_int32(pserver_push_sparse_merge_limit, 12,
             "limit max push_sparse local merge requests");

// NOTE: description previously said "push_sparse" — copy-paste error; this
// flag limits pull_dense, not push_sparse.
DEFINE_int32(pserver_pull_dense_limit, 12,
             "limit max pull_dense local merge requests");

DEFINE_int32(pserver_async_push_dense_interval_ms, 10,
             "async push_dense to server interval");

DEFINE_int32(pserver_async_push_sparse_interval_ms, 10,
             "async push_sparse to server interval");

DEFINE_bool(pserver_scale_gradient_by_merge, false,
            "scale dense gradient when merged");

DEFINE_int32(pserver_communicate_compress_type, 0,
             "none:0 snappy:1 gzip:2 zlib:3 lz4:4");

DEFINE_int32(pserver_max_async_call_num, 13,
             "max task num in async_call_server");

DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");

DEFINE_int32(pserver_connect_timeout_ms, 10000,
             "pserver connect server timeout_ms");

DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");

DEFINE_int32(pserver_sparse_table_shard_num, 1000,
             "sparse table shard for save & load");
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
distributed
{
// Maps a sparse-table key to the index of the pserver that owns its shard:
// shards are distributed ceil(shard_num / server_num) per server.
inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
                               uint64_t key) {
  // Number of shards held by each server (rounded up when uneven).
  size_t shards_per_server = shard_num / server_num;
  if (shard_num % server_num != 0) {
    ++shards_per_server;
  }
  const size_t shard_id = key % shard_num;
  return shard_id / shards_per_server;
}
// brpc entry point for client-to-client messages: forwards the request to the
// owning client's handler and reports the result in the response err fields.
void DownpourPsClientService::service(
    ::google::protobuf::RpcController *controller,
    const PsRequestMessage *request, PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  // Guarantees done->Run() is invoked on every exit path.
  brpc::ClosureGuard done_guard(done);
  int ret = _client->HandleClient2ClientMsg(
      request->cmd_id(), request->client_id(), request->data());
  // Optimistically mark success; overwritten below on failure.
  response->set_err_code(0);
  response->set_err_msg("");
  if (ret != 0) {
    response->set_err_code(-1);
    response->set_err_msg("handle_client2client_msg failed");
  }
}
// Starts the client-side RPC service, used for client-to-client data exchange.
// Binds a brpc server on the first free port in [8500, max_port] and registers
// this client's endpoint with the environment. Returns 0 on success, -1 on
// any configure/start failure.
int32_t BrpcPsClient::StartClientService() {
  if (_service.Configure(this, _client_id) != 0) {
    LOG(ERROR)
        << "service initialize failed, service_name:DownpourPsClientService";
    return -1;
  }
  // The service object is owned by this client, not by the brpc server.
  _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  int start_port = 8500;
  options.num_threads = 24;

  if (_server.Start(butil::my_ip_cstr(), brpc::PortRange(start_port, max_port),
                    &options) != 0) {
    LOG(ERROR) << "BrpcPsServer start failed";
    return -1;
  }
  _server_started = true;
  // Publish the actually-bound port so peer clients can connect to us.
  _env->RegistePsClient(butil::my_ip_cstr(), _server.listen_address().port,
                        _client_id);
  return 0;
}
// Opens a pooled baidu_std channel to every registered peer client.
// On a failed Init with the hostname endpoint, retries once with the
// integer-IP form of the endpoint. Returns 0 on success, -1 if any peer
// remains unreachable.
int32_t BrpcPsClient::CreateClient2ClientConnection(
    int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = pserver_connect_timeout_ms;
  options.max_retry = max_retry;

  std::vector<PSHost> client_list = _env->GetPsClients();
  VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
          << client_list.size();
  for (auto cc : client_list) {
    VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
            << cc.ToString();
  }

  _client_channels.resize(client_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < client_list.size(); ++i) {
    server_ip_port.assign(client_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(client_list[i].port));
    _client_channels[i].reset(new brpc::Channel());
    if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options)) {
      VLOG(0) << "BrpcPSClient connect to Client:" << server_ip_port
              << " Failed! Try again.";
      // Fallback: resolve to the numeric ip:port form and retry once.
      std::string int_ip_port =
          GetIntTypeEndpoint(client_list[i].ip, client_list[i].port);
      if (_client_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "BrpcPSClient connect to Client:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "Client connect success:" << os.str();
  return 0;
}
// One-time client initialization: connects to every pserver, starts the
// client-to-client RPC service, sets up the per-table async push queues,
// registers cost profilers, and launches the async push consumer threads.
// Returns 0 on success, -1 if any server connection fails.
int32_t BrpcPsClient::Initialize() {
  _async_call_num = 0;

  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;

  std::ostringstream os;
  std::string server_ip_port;
  std::string client_ip(butil::my_ip_cstr());

  // Fetch the server list and connect to each server.
  std::vector<PSHost> server_list = _env->GetPsServers();
  _server_channels.resize(server_list.size());
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    // NOTE(review): the inner channel count comes from the pre-existing size
    // of _server_channels[i] — presumably fixed elsewhere (e.g. a
    // std::array member); confirm before changing.
    for (size_t j = 0; j < _server_channels[i].size(); ++j) {
      _server_channels[i][j].reset(new brpc::Channel());
      if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
          0) {
        VLOG(0) << "BrpcPSclient connect to Server:" << server_ip_port
                << " Failed! Try again.";
        // Fallback: retry once with the numeric ip:port endpoint.
        std::string int_ip_port =
            GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
        if (_server_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "BrpcPSclient connect to Server:" << int_ip_port
                     << " Failed!";
          return -1;
        }
      }
    }
    os << server_ip_port << ",";
  }

  // Start the client-side listening service so clients can reach each other.
  StartClientService();

  // Initialize the per-table async push request queues.
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    auto type = worker_param.downpour_table_param(i).type();
    auto table_id = worker_param.downpour_table_param(i).table_id();
    if (type == PS_DENSE_TABLE) {
      _push_dense_task_queue_map[table_id] =
          paddle::framework::MakeChannel<DenseAsyncTask *>();
    }
    if (type == PS_SPARSE_TABLE) {
      _push_sparse_task_queue_map[table_id] =
          paddle::framework::MakeChannel<SparseAsyncTask *>();
      _push_sparse_merge_count_map[table_id] = 0;
    }
  }

  // Register cost profilers for the pull/push code paths.
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_client_pull_dense");
  profiler.register_profiler("pserver_client_pull_sparse");
  profiler.register_profiler("pserver_client_pull_sparse_param");
  profiler.register_profiler("pserver_client_pull_sparse_local");
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_parse");
  profiler.register_profiler("client_push_sparse_put");
  // NOTE(review): "pserver_client_push_sparse" is registered twice —
  // possibly a duplicate; confirm register_profiler is idempotent.
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_merge");
  profiler.register_profiler("pserver_client_push_sparse_rpc");
  profiler.register_profiler("pserver_client_push_dense");
  profiler.register_profiler("pserver_client_push_dense_parse");
  profiler.register_profiler("push_dense_put");
  profiler.register_profiler("pserver_client_push_dense_merge");
  profiler.register_profiler("pserver_client_push_dense_rpc");
  profiler.register_profiler("pserver_client_push_dense_send");

  _running = true;
  _flushing = false;
  // Launch the async push consumer threads (joined in FinalizeWorker).
  _async_push_sparse_thread =
      std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this));
  // _async_push_sparse_thread.detach();
  _async_push_dense_thread =
      std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this));
  // for debug
  // _print_thread =
  //     std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this));
  return 0;
}
int
DownpourBrpcClosure
::
check_response
(
size_t
request_idx
,
int
cmd_id
)
{
if
(
_cntls
[
request_idx
]
->
Failed
())
{
LOG
(
ERROR
)
<<
"resquest cmd_id:"
<<
cmd_id
<<
" failed, "
"err:"
<<
_cntls
[
request_idx
]
->
ErrorText
();
return
-
1
;
}
if
(
_responses
[
request_idx
].
err_code
()
!=
0
)
{
LOG
(
ERROR
)
<<
"response ret bad, server_idx:"
<<
request_idx
<<
"cmd_id:"
<<
cmd_id
<<
" err_code:"
<<
_responses
[
request_idx
].
err_code
()
<<
" err_msg:"
<<
_responses
[
request_idx
].
err_msg
();
return
-
1
;
}
return
0
;
}
int
DownpourBrpcClosure
::
check_save_response
(
size_t
request_idx
,
int
cmd_id
)
{
int32_t
feasign_size
=
0
;
if
(
_cntls
[
request_idx
]
->
Failed
())
{
LOG
(
ERROR
)
<<
"resquest cmd_id:"
<<
cmd_id
<<
" failed, "
"err:"
<<
_cntls
[
request_idx
]
->
ErrorText
();
return
-
1
;
}
feasign_size
=
_responses
[
request_idx
].
err_code
();
if
(
feasign_size
<
0
)
{
LOG
(
ERROR
)
<<
"response ret bad, server_idx:"
<<
request_idx
<<
"cmd_id:"
<<
cmd_id
<<
" err_code:"
<<
_responses
[
request_idx
].
err_code
()
<<
" err_msg:"
<<
_responses
[
request_idx
].
err_msg
();
return
-
1
;
}
return
feasign_size
;
}
// Returns a copy of the response payload for request `request_idx`.
// `cmd_id` is accepted for interface symmetry with check_response but is
// not used here.
std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
  return _responses[request_idx].data();
}
// Queries every pserver for table statistics, sums the per-server feasign
// and mf sizes from the binary-archived responses, and prints the totals to
// stdout. The returned future resolves to 0 on success, -1 on any failure.
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
  size_t request_call_num = _server_channels.size();
  // Closure deletes itself after all responses arrive and the lambda runs.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, table_id](void *done) {
        int ret = 0;
        uint64_t feasign_size = 0;
        uint64_t mf_size = 0;
        paddle::framework::BinaryArchive ar;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
            ret = -1;
            break;
          }
          std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT);
          // Each response carries two uint64 values: feasign size, mf size.
          ar.SetReadBuffer(const_cast<char *>(resp.c_str()), resp.length(),
                           nullptr);
          feasign_size += ar.Get<uint64_t>();
          mf_size += ar.Get<uint64_t>();
        }
        closure->set_promise_value(ret);
        std::cout << "table id: " << table_id
                  << ", feasign size: " << feasign_size
                  << ", mf size: " << mf_size << std::endl;
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(
        10800000);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Broadcasts a control command to every pserver with the given string
// params. The returned future resolves to 0 only if all servers succeed,
// -1 otherwise.
std::future<int32_t> BrpcPsClient::SendCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto &param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // Extra-long timeout: commands like save/load can run for hours.
    closure->cntl(i)->set_timeout_ms(
        10800000 * 2);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Broadcasts a save-style command to every pserver. On success the returned
// future resolves to the summed per-server feasign count; on any failure it
// resolves to -1.
std::future<int32_t> BrpcPsClient::SendSaveCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void *done) {
        int ret = 0;
        uint32_t feasign_size = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          // check_save_response returns the per-server feasign count, or a
          // negative value on failure. Call it once per server and reuse
          // the result (the original called it twice, re-parsing the
          // response a second time).
          int32_t cur_size = closure->check_save_response(i, cmd_id);
          if (cur_size < 0) {
            ret = -1;
            break;
          }
          feasign_size += cur_size;
        }
        if (ret == 0) {
          closure->set_promise_value(feasign_size);
        } else {
          closure->set_promise_value(ret);
        }
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto &param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // Extra-long timeout: save/load can take hours on large tables.
    closure->cntl(i)->set_timeout_ms(
        10800000);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Asks every pserver to shrink the given table using `threshold`.
std::future<int32_t> BrpcPsClient::Shrink(uint32_t table_id,
                                          const std::string threshold) {
  return SendCmd(table_id, PS_SHRINK_TABLE, {threshold});
}
// Loads every table from `epoch` with the given mode (-1 == all tables).
std::future<int32_t> BrpcPsClient::Load(const std::string &epoch,
                                        const std::string &mode) {
  return SendCmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode});
}
// Loads a single table from `epoch` with the given mode.
std::future<int32_t> BrpcPsClient::Load(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  return SendCmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}
// Saves every table to path `epoch`; the future carries the total feasign
// count on success (see SendSaveCmd).
std::future<int32_t> BrpcPsClient::Save(const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save path " << epoch;
  return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
}
// Saves a single table to path `epoch`.
std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
          << table_id;
  return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
// Triggers a cache shuffle on one table with the given path/mode/threshold.
std::future<int32_t> BrpcPsClient::CacheShuffle(
    uint32_t table_id, const std::string &path, const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
  return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold});
}
// Triggers a cache shuffle across several tables sharing one output path.
// Params layout: [path, mode, cache_threshold, table_id...]; the table_id
// field of the request itself is set to 0.
std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
    std::vector<int> tables, const std::string &path, const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
  std::vector<std::string> param;
  param.push_back(path);
  param.push_back(mode);
  param.push_back(cache_threshold);
  for (size_t i = 0; i < tables.size(); i++) {
    param.push_back(std::to_string(tables[i]));
  }
  return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}
// Saves the cache portion of one table to `path`.
std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
                                             const std::string &path,
                                             const std::string &mode) {
  return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode});
}
// Queries every pserver for its cache threshold and writes the average of
// the non-negative values into `cache_threshold` (0 if none). The caller
// must keep `cache_threshold` alive until the returned future resolves —
// the callback writes through the captured reference.
std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
                                                     double &cache_threshold) {
  int cmd_id = PS_GET_CACHE_THRESHOLD;
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, cmd_id, &cache_threshold](void *done) {
        int ret = 0;
        // reinterpret_cast for consistency with every other closure in this
        // file (was a C-style cast).
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        std::vector<double> cache_thresholds(request_call_num, 0);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
          // Each server returns its threshold as a decimal string.
          std::string cur_res = closure->get_response(i, cmd_id);
          cache_thresholds[i] = std::stod(cur_res);
        }
        double sum_threshold = 0.0;
        int count = 0;
        for (auto t : cache_thresholds) {
          if (t >= 0) {
            sum_threshold += t;
            ++count;
          }
        }
        if (count == 0) {
          cache_threshold = 0;
        } else {
          cache_threshold = sum_threshold / count;
        }
        VLOG(1) << "client get cache threshold: " << cache_threshold;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Clears all tables on every pserver.
std::future<int32_t> BrpcPsClient::Clear() {
  return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
// Clears a single table on every pserver.
std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
  return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
}
// Blocks until all in-flight async push calls have drained, then returns an
// already-resolved future(0). Sets _flushing while waiting so producers can
// back off.
std::future<int32_t> BrpcPsClient::Flush() {
  VLOG(0) << "BrpcPsClient::flush begin";
  _flushing = true;
  std::promise<int> promise;
  std::future<int32_t> fut = promise.get_future();
  // NOTE(review): do-while always sleeps at least once (100ms) even when
  // _async_call_num is already 0 — presumably intentional slack for tasks
  // in flight; confirm before tightening to a while loop.
  do {
    VLOG(3) << "wait _async_call_num:" << _async_call_num;
    usleep(100000);  // sleep 100ms wait async end
  } while (_async_call_num > 0);
  VLOG(1) << "flush _async_call_num = 0";
  promise.set_value(0);
  _flushing = false;
  VLOG(0) << "BrpcPsClient::flush done";
  PrintQueueSize();
  return fut;
}
// Logs the current size of every per-table async push queue (sparse then
// dense) at VLOG(0).
void BrpcPsClient::PrintQueueSize() {
  for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
    auto table_id = push_sparse_task_itr.first;
    auto queue_size = push_sparse_task_itr.second->Size();
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id
            << " size: " << queue_size;
  }

  for (auto &task_queue_itr : _push_dense_task_queue_map) {
    auto table_id = task_queue_itr.first;
    auto queue_size = task_queue_itr.second->Size();
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id
            << " size: " << queue_size;
  }
}
// Debug thread body: dumps queue sizes every 2 minutes while the client runs.
void BrpcPsClient::PrintQueueSizeThread() {
  while (_running) {
    usleep(1000000 * 60 * 2);  // 2 minutes
    PrintQueueSize();
  }
}
// Orderly shutdown: flush pending async pushes, stop and join the consumer
// threads, then stop and join the client-side brpc server.
void BrpcPsClient::FinalizeWorker() {
  Flush();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread";
  _running = false;  // signals consumer threads to exit
  _async_push_dense_thread.join();
  _async_push_sparse_thread.join();
  // _print_thread.join();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server";
  _server.Stop(1000);
  _server.Join();
  _server_started = false;
  VLOG(0) << "BrpcPsClient::FinalizeWorker done";
}
// Asks every pserver to shut down.
std::future<int32_t> BrpcPsClient::StopServer() {
  return SendCmd(-1, PS_STOP_SERVER, {});
}
// Enables profiling on every pserver.
std::future<int32_t> BrpcPsClient::StartProfiler() {
  return SendCmd(-1, PS_START_PROFILER, {});
}
// Disables profiling on every pserver.
std::future<int32_t> BrpcPsClient::StopProfiler() {
  return SendCmd(-1, PS_STOP_PROFILER, {});
}
// Issues a barrier of the given type on a table across all pservers.
std::future<int32_t> BrpcPsClient::Barrier(size_t table_id,
                                           uint32_t barrier_type) {
  return SendCmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
// GEO mode: pulls updated sparse parameters from one pserver. The response
// attachment layout is [uint32 shard_nums][shard_nums * uint64 keys]
// [shard_nums * update_size bytes of values]; keys/values are resized and
// filled by the callback.
std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
                                                std::vector<float> *values,
                                                std::vector<uint64_t> *keys,
                                                int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        uint32_t shard_nums;
        if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) {
          ret = -1;
        }
        // NOTE(review): the attachment is parsed even when check_response
        // failed above — presumably harmless because the buffer is empty;
        // confirm.
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
                                       sizeof(uint32_t));
        keys->resize(shard_nums);
        values->resize(shard_nums * accessor->GetAccessorInfo().update_dim);
        io_buffer_itr.copy_and_forward((void *)(keys->data()),  // NOLINT
                                       sizeof(uint64_t) * shard_nums);
        io_buffer_itr.copy_and_forward(
            (void *)(values->data()),  // NOLINT
            shard_nums * accessor->GetAccessorInfo().update_size);
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  PsService_Stub rpc_stub(GetCmdChannel(pserver_idx));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
                   closure);
  return fut;
}
// for GEO
// Pushes sparse parameter values (not gradients) to the pservers. Keys are
// sharded by key % server_count; each shard's payload is
// [kv_size * uint64 keys][kv_size * update_size bytes of values], with
// kv_size passed as a request param. `done` must be a DownpourBrpcClosure
// sized for one request per server; ownership passes to brpc.
std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
                                                   const uint64_t *keys,
                                                   const float **update_values,
                                                   size_t num, void *done) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  size_t request_call_num = _server_channels.size();
  // Bucket keys/value pointers per destination pserver.
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = keys[i] % request_call_num;
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    // const references — the original copied both vectors per shard.
    const auto &kvs = ids[shard_idx];
    const auto &value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    // Build and send the RPC request for this shard.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    auto *push_data = push_request->mutable_data();
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    // Keys first, then the per-key value blocks.
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
                     closure->response(shard_idx), closure);
  }
  return fut;
}
// Pulls the dense table from all pservers. Each server returns exactly
// num_per_shard * select_size bytes; the callback concatenates the shard
// buffers, in server order, into the caller's regions (filling each region
// before moving to the next). The future resolves to 0 on success, -1 on
// any RPC failure or size mismatch.
std::future<int32_t> BrpcPsClient::PullDense(Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
  auto *accessor = GetTableAccessor(table_id);
  auto fea_dim = accessor->GetAccessorInfo().fea_dim;
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Callback: fill each shard's result into the regions, in order.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, num_per_shard, regions, region_num,
       accessor](void *done) {
        int ret = 0;
        size_t region_idx = 0;       // region currently being filled
        size_t region_data_idx = 0;  // byte offset within that region
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        size_t shard_data_size =
            num_per_shard * accessor->GetAccessorInfo().select_size;
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          size_t shard_buffer_remain = res_io_buffer.size();
          // Every shard must return exactly shard_data_size bytes.
          if (shard_buffer_remain != shard_data_size) {
            LOG(ERROR) << "expect res_size:" << shard_data_size
                       << ", but size:" << shard_buffer_remain
                       << ", ignore this response";
            ret = -1;
            break;
          }
          while (shard_buffer_remain > 0 && region_idx < region_num) {
            auto &region = regions[region_idx];
            if (region.size - region_data_idx >= shard_buffer_remain) {
              // Remaining region space >= shard buffer: copy it all in.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  shard_buffer_remain);
              region_data_idx += shard_buffer_remain;
              shard_buffer_remain = 0;
            } else if (region.size - region_data_idx == 0) {
              // Region full: switch to the next region.
              ++region_idx;
              region_data_idx = 0;
            } else {
              // Region cannot hold everything: copy as much as fits.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  region.size - region_data_idx);
              shard_buffer_remain -= (region.size - region_data_idx);
              ++region_idx;
              region_data_idx = 0;
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    closure->request(i)->add_params((char *)&num_per_shard,  // NOLINT
                                    sizeof(num_per_shard));
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Pushes dense parameter values (not gradients) to the pservers. The flat
// sequence of input regions is partitioned into per-server slices of exactly
// shard_data_size bytes; slices shorter than that (at the tail) are padded
// from a zero-filled-by-static-init scratch buffer so all shards align.
std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
                                                  size_t region_num,
                                                  size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  auto accessor_info = accessor->GetAccessorInfo();
  size_t request_call_num = _server_channels.size();
  // 1. Split the region data into per-shard partitions; the copies into the
  //    per-shard request buffers happen in the send loop below.
  std::vector<std::vector<Region>> regions_partition(request_call_num);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor_info.fea_dim, request_call_num);
  size_t shard_data_size = num_per_shard * accessor_info.update_size;
  size_t current_region_idx = 0;
  size_t current_region_data_idx = 0;
  for (size_t i = 0; i < request_call_num; ++i) {
    size_t shard_data_remain_size = shard_data_size;
    while (shard_data_remain_size > 0 && current_region_idx < region_num) {
      const auto &region = regions[current_region_idx];
      size_t region_remain_size = region.size - current_region_data_idx;
      if (shard_data_remain_size >= region_remain_size) {
        // This shard consumes the rest of the current region.
        regions_partition[i].push_back(
            Region(region.data + current_region_data_idx, region_remain_size));
        ++current_region_idx;
        current_region_data_idx = 0;
        shard_data_remain_size -= region_remain_size;
      } else {
        // Region spans a shard boundary: take only what this shard needs.
        regions_partition[i].push_back(Region(
            region.data + current_region_data_idx, shard_data_remain_size));
        current_region_data_idx += shard_data_remain_size;
        shard_data_remain_size = 0;
      }
    }
  }

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Scratch buffer used to pad short shards up to shard_data_size.
  static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
  static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE];
  // 2. Copy each shard's regions into its request attachment and send.
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto &request_buffer = closure->cntl(i)->request_attachment();
    // Payload starts with the per-shard element count.
    request_buffer.append(reinterpret_cast<void *>(&num_per_shard),
                          sizeof(uint32_t));
    auto &region_list = regions_partition[i];
    size_t fill_remain_size = shard_data_size;
    for (auto &region : region_list) {
      fill_remain_size -= region.size;
      request_buffer.append(reinterpret_cast<void *>(region.data),
                            region.size);
    }
    // Pad so every shard's payload is exactly shard_data_size bytes.
    while (fill_remain_size > 0) {
      size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
                            ? REGION_ASSIGN_BUFFER_SIZE
                            : fill_remain_size;
      request_buffer.append(reinterpret_cast<void *>(region_assign_buffer),
                            fill_num);
      fill_remain_size -= fill_num;
    }
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Shards `num` (key, gradient) pairs by key and sends one
// PS_PUSH_SPARSE_TABLE request per parameter server shard.
//
// `done` must be a caller-built DownpourBrpcClosure holding one sub-request
// per server; this function appends a fresh promise to it and returns the
// matching future, which resolves when the closure runs.
// Each per-shard payload is laid out as:
//   [kv_size uint64 keys][kv_size fixed-width gradient blobs].
std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
    size_t table_id, const uint64_t *keys, const float **update_values,
    size_t num, void *done) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  size_t request_call_num = _server_channels.size();
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);

  // Shard count defaults to the global flag but is overridden by the
  // table's own configured shard_num when the table is found in the config.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }

  // Bucket every key (and its gradient pointer) by destination server.
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }

  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    // FIX: bind by reference. The original copied both per-shard vectors
    // (`auto kvs = ids[shard_idx];`) on every iteration even though they
    // are only read below — pure allocation/copy overhead on the push path.
    auto &kvs = ids[shard_idx];
    auto &value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;

    // Fill in the RPC header; the kv count travels as a request param.
    // NOTE(review): only sizeof(uint32_t) bytes of the size_t `kv_size` are
    // sent — this matches the rest of the wire protocol but assumes a
    // little-endian host.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params(reinterpret_cast<char *>(&kv_size),
                             sizeof(uint32_t));

    // Serialize keys first, then the fixed-width gradient blobs.
    auto *push_data = push_request->mutable_data();
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }

    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
                     closure->response(shard_idx), closure);
  }
  return fut;
}
// Sends one dense-gradient push request per parameter server shard.
//
// `total_send_data` holds `request_call_num` contiguous slices of
// `num_per_shard` floats each; slice i is sent over dense channel i.
// `done` is a caller-built DownpourBrpcClosure; a new promise is attached
// to it and the matching future returned.
std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
    int table_id, float *total_send_data, size_t total_send_data_size,
    void *done) {
  size_t request_call_num = _server_channels.size();
  auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  auto *accessor = GetTableAccessor(table_id);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);

  for (size_t shard_id = 0; shard_id < request_call_num; ++shard_id) {
    auto *request = closure->request(shard_id);
    request->set_cmd_id(PS_PUSH_DENSE_TABLE);
    request->set_table_id(table_id);
    request->set_client_id(_client_id);

    // Payload layout: [uint32 num_per_shard][num_per_shard floats].
    auto *payload = request->mutable_data();
    payload->clear();
    payload->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *cursor = const_cast<char *>(payload->data());
    memcpy(cursor, &num_per_shard, sizeof(uint32_t));
    memcpy(cursor + sizeof(uint32_t),
           total_send_data + shard_id * num_per_shard,
           num_per_shard * sizeof(float));

    // Request compression is deliberately left disabled here:
    // closure->cntl(i)->set_request_compress_type(
    //     (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(shard_id));
    rpc_stub.service(closure->cntl(shard_id), request,
                     closure->response(shard_id), closure);
  }
  return fut;
}
// Broadcasts a global-step value to every parameter server shard via
// PS_PUSH_GLOBAL_STEP. `done` is a caller-built DownpourBrpcClosure with
// one sub-request per server; a fresh promise is attached and its future
// returned.
std::future<int32_t> BrpcPsClient::PushGlobalStep(int table_id,
                                                  int64_t *total_send_data,
                                                  void *done) {
  size_t request_call_num = _server_channels.size();
  auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t shard_id = 0; shard_id < request_call_num; ++shard_id) {
    auto *request = closure->request(shard_id);
    request->set_cmd_id(PS_PUSH_GLOBAL_STEP);
    request->set_table_id(table_id);
    request->set_client_id(_client_id);

    // Payload layout: [uint32 count = 1][one int64 step value].
    int32_t num_per_shard = 1;
    auto *payload = request->mutable_data();
    payload->clear();
    payload->resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
    char *cursor = const_cast<char *>(payload->data());
    memcpy(cursor, &num_per_shard, sizeof(uint32_t));
    memcpy(cursor + sizeof(uint32_t), total_send_data,
           num_per_shard * sizeof(int64_t));

    PsService_Stub rpc_stub(GetDenseChannel(shard_id));
    rpc_stub.service(closure->cntl(shard_id), request,
                     closure->response(shard_id), closure);
  }
  return fut;
}
// Pulls `num` sparse rows identified by `keys` from the parameter servers
// into the caller-provided row pointers `select_values[i]`.
//
// Flow: keys are bucketed per server shard, sorted, and deduplicated on the
// wire (each unique key is sent once with a repeat count); the response
// callback streams each unique value out of the response attachment once
// and memcpy-fans it out to every duplicate's destination pointer.
// Returns a future that resolves to 0 on success, -1 on any shard failure.
std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
                                              size_t table_id,
                                              const uint64_t *keys,
                                              size_t num,
                                              bool is_training) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse");
  auto local_timer =
      std::make_shared<CostTimer>("pserver_client_pull_sparse_local");
  size_t request_call_num = _server_channels.size();
  // One (key, destination-pointer) list per server shard; shared with the
  // response lambda so it stays alive until the closure runs.
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);

  // Shard count defaults to the global flag, overridden by the table's own
  // configured shard_num if the table is found in the server config.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }

  // Bucket each requested key (with its output pointer) by server shard.
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }

  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;

  // Response handler: for each shard, walk the sorted kv list in lockstep
  // with the response attachment. The server returns one value per unique
  // key, in request order; duplicates are satisfied by copying from the
  // previously read value.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              // Duplicate key: reuse the value already read from the wire.
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              // Stream the next value directly into the caller's buffer; a
              // short read means the response is truncated or malformed.
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data),
                      value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t i = 0; i < request_call_num; ++i) {
    // Sort this shard's keys so duplicates become adjacent and the server
    // sees a sorted, deduplicated request.
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    std::sort(sorted_kvs.begin(), sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();

    // Wire layout: [bool is_training][unique keys...][per-key repeat counts].
    request_buffer.append(reinterpret_cast<void *>(&is_training),
                          sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);

    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      // NOTE: this local `keys` (repeat count) shadows the `keys` parameter.
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse the run of duplicates into a single count.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }

    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());

    if (kv_request_count == 0) {
      // Nothing to ask this shard: complete its slot of the closure locally.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(closure->cntl(i), closure->request(i),
                       closure->response(i), closure);
    }
  }
  return fut;
}
// GEO-mode variant of PullSparse: pulls sparse parameter rows for the
// geo-sgd path. Differs from PullSparse in that keys are sharded by plain
// modulo over the number of servers (not get_sparse_shard) and no per-table
// shard_num lookup is done. Same wire format and dedup scheme otherwise.
std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
                                                   size_t table_id,
                                                   const uint64_t *keys,
                                                   size_t num,
                                                   bool is_training) {
  auto timer =
      std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
  size_t request_call_num = _server_channels.size();
  // One (key, destination-pointer) list per server; shared with the
  // response lambda so it outlives this call.
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);

  // GEO sharding: plain modulo over the server count.
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = keys[i] % request_call_num;
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }

  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;

  // Response handler: stream one value per unique key out of the response
  // attachment; duplicates reuse the previously read value.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          // can remove sort&unique
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              // Duplicate key: copy the value read for the previous entry.
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              // Short read => truncated or malformed response.
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data),
                      value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t i = 0; i < request_call_num; ++i) {
    // Sort so duplicates are adjacent; the request carries each unique key
    // once plus a repeat count.
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    std::sort(sorted_kvs.begin(), sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();

    // Wire layout: [bool is_training][unique keys...][per-key repeat counts].
    request_buffer.append(reinterpret_cast<void *>(&is_training),
                          sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);

    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      // NOTE: this local `keys` (repeat count) shadows the `keys` parameter.
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }

    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());

    if (kv_request_count == 0) {
      // Nothing for this shard: complete its closure slot without an RPC.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(closure->cntl(i), closure->request(i),
                       closure->response(i), closure);
    }
  }
  return fut;
}
// Sends a client-to-client message of `msg_type` carrying payload `msg` to
// the peer worker `to_client_id`. Returns a future that resolves to the
// peer's response status (the response is validated against cmd id
// msg_type + 1000, per check_response below), or -1 immediately if the
// destination id is invalid.
std::future<int32_t> BrpcPsClient::SendClient2ClientMsg(
    int msg_type, int to_client_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  // FIX: reject negative ids too. The original guard was
  // `to_client_id >= 0 && ... >= size()`, so a negative id skipped the
  // check entirely and then indexed _client_channels out of bounds
  // (undefined behavior).
  if (to_client_id < 0 ||
      static_cast<size_t>(to_client_id) >= _client_channels.size()) {
    VLOG(0) << "to_client_id is out of range clients, which size is "
            << _client_channels.size();
    promise->set_value(-1);
    return fut;
  }
  auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) {
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
    // Peer replies are checked against msg_type + 1000.
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(msg_type);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_client_channels[to_client_id].get());
  rpc_stub.service(closure->cntl(0), closure->request(0),
                   closure->response(0), closure);
  return fut;
}
// Pushes a batch of raw sparse gradients to a single, caller-chosen
// parameter server (`pserver_idx`) instead of sharding across all servers.
// `done` is a caller-built single-slot DownpourBrpcClosure; a new promise
// is attached to it and the matching future returned.
std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
    size_t table_id, const uint64_t *keys, const float **update_values,
    uint32_t num, void *done, int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  const size_t value_size = accessor->GetAccessorInfo().update_size;
  auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  // Header fields; the element count travels as a request param.
  auto *push_request = closure->request(0);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  push_request->add_params((char *)&num, sizeof(uint32_t));  // NOLINT

  // Payload layout: [num uint64 keys][num fixed-width gradient blobs].
  auto *push_data = push_request->mutable_data();
  push_data->resize(num * (sizeof(uint64_t) + value_size));
  char *cursor = const_cast<char *>(push_data->data());
  memcpy(cursor, keys, num * sizeof(uint64_t));
  cursor += num * sizeof(uint64_t);
  for (uint32_t idx = 0; idx < num; ++idx) {
    memcpy(cursor, update_values[idx], value_size);
    cursor += value_size;
  }

  PsService_Stub rpc_stub(GetSparseChannel(pserver_idx));
  closure->cntl(0)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(closure->cntl(0), closure->request(0),
                   closure->response(0), closure);
  return fut;
}
// Pulls an entire sparse table from the servers and serializes it to disk
// as a LoDTensor under `path`/<table_name>. Returns 0 on success; raises a
// PADDLE_ENFORCE error if the table id is unknown or the file can't open.
int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
                                       const std::string &path) {
  // Look up the table's name, row count, dim, and class from the worker
  // config; var_name staying empty means the table id was not found.
  std::string var_name = "";
  int64_t var_num = 0;
  int64_t var_shape = 0;
  std::string table_class;
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    if (worker_param.downpour_table_param(i).table_id() == table_id) {
      var_name = worker_param.downpour_table_param(i).common().table_name();
      var_num = worker_param.downpour_table_param(i).common().table_num();
      var_shape = worker_param.downpour_table_param(i).common().table_dim();
      table_class = worker_param.downpour_table_param(i).table_class();
      break;
    }
  }

  PADDLE_ENFORCE_NE(
      var_name, "",
      platform::errors::InvalidArgument(
          "Cannot find table id %d to save variables.", table_id));

  std::string var_store = string::Sprintf("%s", path);
  MkDirRecursively(var_store.c_str());

  // Pull every row [0, var_num) into one flat buffer; save_vec holds one
  // pointer per row into that buffer for the pull API.
  std::vector<float> save_huge_vec(var_num * var_shape);
  std::vector<uint64_t> save_key(var_num);
  std::vector<float *> save_vec;
  for (size_t i = 0; i < save_key.size(); ++i) {
    save_key[i] = i;
    save_vec.push_back(save_huge_vec.data() + i * var_shape);
  }

  VLOG(2) << "RecvAndSaveTable: table_class: " << table_class;
  // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
  // RecvAndSaveTable
  // GEO tables use the modulo-sharded pull path; everything else uses the
  // standard PullSparse path.
  if (table_class == "MemorySparseGeoTable") {
    auto status = PullSparseParam(reinterpret_cast<float **>(save_vec.data()),
                                  table_id, save_key.data(), save_key.size(),
                                  true);
    status.wait();
  } else {
    auto status = PullSparse(reinterpret_cast<float **>(save_vec.data()),
                             table_id, save_key.data(), save_key.size(),
                             true);
    status.wait();
  }

  // Build a CPU LoDTensor of shape [var_num, var_shape] in a fresh scope.
  std::shared_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  auto place = platform::CPUPlace();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);

  framework::Variable *var = scope->Var(var_name);
  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();

  std::vector<int64_t> vec_dim = {var_num, var_shape};
  var_tensor->Resize(phi::make_ddim(vec_dim));

  // Copy the pulled rows into the tensor and serialize it to disk.
  float *tensor_data = var_tensor->mutable_data<float>(place);
  memcpy(tensor_data, save_huge_vec.data(),
         var_num * var_shape * sizeof(float));

  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
  std::ofstream fout(file_name, std::ios::binary);
  PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
                    platform::errors::Unavailable(
                        "Cannot open %s to save variables.", file_name));

  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
  fout.close();
  return 0;
}
// Asynchronous sparse push: copies the (key, gradient) pairs into a
// SparseAsyncTask (sharded by key) and enqueues it on the per-table task
// queue; the consumer thread (PushSparseTaskConsume) merges and sends it
// later. Returns the task's future.
std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
                                              const uint64_t *keys,
                                              const float **update_values,
                                              size_t num) {
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse");
  CostTimer parse_timer("pserver_client_push_sparse_parse");
  // Back-pressure: block while the per-table queue exceeds the async limit.
  int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushSparse Waiting for async_call_num comsume,
    // task_num:"
    //           << push_sparse_async_num
    //           << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  }
  auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
  // Per-thread scratch buckets, reused (cleared) across calls to avoid
  // reallocating one vector per shard on every push.
  thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
      shard_sorted_kv_list;
  auto *accessor = GetTableAccessor(table_id);
  size_t request_call_num = _server_channels.size();
  shard_sorted_kv_list.resize(request_call_num);
  for (auto &x : shard_sorted_kv_list) {
    x.clear();
  }
  // Shard count defaults to the global flag, overridden by the table's own
  // configured shard_num when found.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  // Bucket each pair by destination server shard.
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]});
  }
  // Task payload comes from a pool; one shard-data slot per server.
  auto sparse_task_data = _sparse_task_pool.get();
  sparse_task_data->shared_data.resize(request_call_num);
  auto async_task =
      new SparseAsyncTask(sparse_task_data, table_id, push_timer);

  // Deep-copy keys and gradient bytes into the task so the caller's
  // buffers need not outlive this call.
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kv_list = shard_sorted_kv_list[i];
    size_t sorted_kv_size = sorted_kv_list.size();
    auto &shard_kv_data = async_task->data()->shared_data[i];
    shard_kv_data.key_list.resize(sorted_kv_size);
    shard_kv_data.value_list.resize(sorted_kv_size);

    if (sorted_kv_size == 0) {
      shard_kv_data.kv_num = 0;
      continue;
    }
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
      shard_kv_data.value_list[kv_idx].assign(
          (const char *)sorted_kv_list[kv_idx].second, value_size);
    }
    shard_kv_data.kv_num = sorted_kv_size;
  }
  std::future<int> fut = async_task->get_future();
  _push_sparse_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
// Background consumer loop for asynchronous sparse pushes. While _running,
// it drains each table's task queue, merges queued tasks per server shard
// on a worker thread pool, and either sends the merged batch (when the
// accumulated merge count reaches FLAGS_pserver_push_sparse_merge_limit or
// a flush is requested) or re-queues the merged result for further
// accumulation.
void BrpcPsClient::PushSparseTaskConsume() {
  uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
  std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
  size_t request_call_num = _server_channels.size();
  ::ThreadPool async_push_sparse_shard_threads(
      FLAGS_pserver_sparse_merge_thread);
  while (_running) {
    auto async_start_time_ms = butil::gettimeofday_ms();
    // Process the pending push tasks of every sparse table.
    for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
      auto table_id = push_sparse_task_itr.first;
      auto *accessor = GetTableAccessor(table_id);
      auto &task_queue = push_sparse_task_itr.second;
      auto queue_size = task_queue->Size();
      if (queue_size == 0) {
        continue;
      }
      // With merging enabled, wait for at least two queued tasks unless a
      // flush has been requested.
      if (merge_size > 0 && (queue_size <= 1 && _flushing == false)) {
        continue;
      }
      ++_async_call_num;
      int merge_count = 0;
      // Return the previous round's task payloads to the pool before
      // rebuilding task_list.
      for (size_t i = 0; i < task_list.size(); ++i) {
        if (task_list[i]->data()) {
          _sparse_task_pool.push(task_list[i]->data());
        }
      }
      auto sparse_task_data = _sparse_task_pool.get();

      task_list.clear();
      int cur_meger_size = task_queue->Size();

      // task_list[0] is an empty SparseAsyncTask; the per-shard async merge
      // results are written into it.
      sparse_task_data->shared_data.resize(request_call_num);
      auto push_timer =
          std::make_shared<CostTimer>("pserver_client_push_sparse");

      auto async_task =
          new SparseAsyncTask(sparse_task_data, table_id, push_timer);

      task_list.reserve(cur_meger_size + 1);
      task_list.push_back(
          std::move(std::shared_ptr<SparseAsyncTask>(async_task)));

      // Drain up to cur_meger_size tasks from the queue into task_list[1..].
      while (!task_queue->Empty() && merge_count < cur_meger_size) {
        ++merge_count;
        SparseAsyncTask *task;
        task_queue->Get(task);
        task_list.push_back(std::shared_ptr<SparseAsyncTask>(task));
      }

      _push_sparse_merge_count_map[table_id] += merge_count;

      // Send once the accumulated count reaches merge_size (or a flush is
      // in progress); otherwise only merge locally.
      std::vector<int> request_kv_num(request_call_num, 0);

      if (_push_sparse_merge_count_map[table_id] >= merge_size ||
          _flushing == true) {
        // RPC completion handler: fail the batch if any shard's response is
        // bad, then release one async-call slot.
        DownpourBrpcClosure *closure = new DownpourBrpcClosure(
            request_call_num, [this, request_call_num](void *done) {
              int ret = 0;
              auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
              for (size_t i = 0; i < request_call_num; ++i) {
                if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
                  ret = -1;
                  break;
                }
              }
              closure->set_promise_value(ret);
              --_async_call_num;
            });

        // Every merged task's timer and promise ride on this closure so
        // all original callers are notified on completion.
        for_each(task_list.begin() + 1, task_list.end(),
                 [&request_kv_num, request_call_num,
                  closure](std::shared_ptr<SparseAsyncTask> &task) {
                   closure->add_timer(task->timer());
                   closure->add_promise(task->promise());
                 });

        CostTimer merge_timer("pserver_client_push_sparse_merge");
        auto rpc_timer =
            std::make_shared<CostTimer>("pserver_client_push_sparse_rpc");
        closure->add_timer(rpc_timer);

        // Merge + send each shard in parallel, then wait for all shards.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num;
             ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardPush, this,
                        task_list, request_kv_num, table_id, shard_idx,
                        closure, accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num;
             ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
        _push_sparse_merge_count_map[table_id] = 0;
      } else {
        // Threshold not reached: only do the multi-way merge per shard.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num;
             ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardMerge, this,
                        task_list, request_kv_num, table_id, shard_idx,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num;
             ++shard_idx) {
          merge_status[shard_idx].wait();
        }

        // Re-queue a copy of the merged result (task_list[0]) so it can be
        // merged with future tasks.
        auto async_task = new SparseAsyncTask(*(task_list[0].get()));

        task_queue->Put(std::move(async_task));
        --_async_call_num;
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
      }
    }
    // Sleep out the remainder of the configured polling interval.
    auto wait_ms = FLAGS_pserver_async_push_sparse_interval_ms -
                   (butil::gettimeofday_ms() - async_start_time_ms);
    if (wait_ms > 0) {
      usleep(wait_ms * 1000);
    }
  }
}
// Merges one sparse gradient row (`another_data`) into another
// (`merge_data`) in place, column by column, via the accessor's Merge().
// Both rows must hold at least update_dim floats.
void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
                        const float *another_data) {
  size_t col_num = accessor->GetAccessorInfo().update_dim;
  // Build per-column pointer "shells" so Merge() sees each column as a
  // one-row array.
  // FIX: the original used C variable-length arrays (`float *shell[col_num]`),
  // which are not standard C++ and risk stack overflow for large update_dim;
  // use heap-backed vectors instead.
  std::vector<float *> merge_data_shell(col_num);
  std::vector<const float *> another_data_shell(col_num);
  for (size_t i = 0; i < col_num; ++i) {
    merge_data_shell[i] = merge_data + i;
    another_data_shell[i] = another_data + i;
  }
  accessor->Merge(merge_data_shell.data(), another_data_shell.data(), 1);
}
// Merges shard `shard_idx` of every queued task in task_list[1..] into
// task_list[0]: gathers all (key, value) pairs, sorts by key, and collapses
// duplicate keys via sparse_local_merge. The merged result is written into
// task_list[0]'s shard data. Always returns 0.
int BrpcPsClient::PushSparseAsyncShardMerge(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num, int table_id, int shard_idx,
    ValueAccessor *accessor) {
  size_t merged_kv_count = 0;
  uint32_t value_size = accessor->GetAccessorInfo().update_size;
  // Per-thread scratch list, reused across calls.
  thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
  sorted_kv_list.clear();
  // Collect this shard's pairs from every queued task (task_list[0] is the
  // merge destination and is skipped).
  for (size_t i = 1; i < task_list.size(); ++i) {
    size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num;
    auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list;
    auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list;

    for (size_t j = 0; j < kv_num; ++j) {
      // Skip values shorter than one full update record.
      if (value_list[j].size() < value_size) {
        LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str()
                     << "is invalid.";
        continue;
      }
      char *task_data_ptr = const_cast<char *>(value_list[j].data());
      sorted_kv_list.push_back(
          {key_list[j], reinterpret_cast<float *>(task_data_ptr)});
    }
  }

  // Sort by key so duplicates become adjacent for deduplication.
  std::sort(sorted_kv_list.begin(), sorted_kv_list.end(),
            [](const std::pair<uint64_t, const float *> &k1,
               const std::pair<uint64_t, const float *> &k2) {
              return k1.first < k2.first;
            });

  auto &async_task = task_list[0];
  size_t sorted_kv_size = sorted_kv_list.size();
  auto &shard_kv_data = async_task->data()->shared_data[shard_idx];
  shard_kv_data.key_list.resize(sorted_kv_size);
  shard_kv_data.value_list.resize(sorted_kv_size);

  // Fast paths: nothing to merge, or a single pair copied verbatim.
  if (sorted_kv_size == 0) {
    shard_kv_data.kv_num = 0;
    return 0;
  } else if (sorted_kv_size == 1) {
    shard_kv_data.kv_num = 1;
    shard_kv_data.key_list[0] = sorted_kv_list[0].first;
    shard_kv_data.value_list[0].assign(
        (const char *)(sorted_kv_list[0].second), value_size);
    return 0;
  }

  // General case: walk the sorted list, locally merging runs of equal keys
  // into merger_buffer before emitting one record per unique key.
  uint64_t last_key = sorted_kv_list[0].first;
  const float *last_value_data = sorted_kv_list[0].second;
  float *last_merge_data = NULL;
  std::shared_ptr<char> merger_buffer(new char[value_size],
                                      array_deleter<char>());
  for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) {
    // Accumulate all duplicates of last_key into merger_buffer.
    while (kv_idx < sorted_kv_size &&
           last_key == sorted_kv_list[kv_idx].first) {
      if (last_merge_data == NULL) {
        last_merge_data = reinterpret_cast<float *>(merger_buffer.get());
        memcpy(last_merge_data, last_value_data, value_size);
      }
      sparse_local_merge(accessor, last_merge_data,
                         sorted_kv_list[kv_idx].second);
      ++kv_idx;
    }
    // Emit either the merged buffer or the single (unduplicated) value.
    if (last_merge_data != NULL) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_merge_data, value_size);
      last_merge_data = NULL;
    } else {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)sorted_kv_list[kv_idx - 1].second, value_size);
    }
    shard_kv_data.key_list[merged_kv_count++] = last_key;
    // Advance to the next unique key, if any.
    if (kv_idx < sorted_kv_size) {
      last_key = sorted_kv_list[kv_idx].first;
      last_value_data = sorted_kv_list[kv_idx].second;
    }
    // The final element, when it starts a new run, is emitted here since
    // the loop will terminate before revisiting it.
    if (kv_idx == sorted_kv_size - 1) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_value_data, value_size);
      shard_kv_data.key_list[merged_kv_count++] = last_key;
    }
  }
  shard_kv_data.kv_num = merged_kv_count;
  return 0;
}
// Merges shard `shard_idx` across the queued tasks (delegating to
// PushSparseAsyncShardMerge) and then sends the merged result to that
// shard's server as a PS_PUSH_SPARSE_TABLE request through the shared
// `closure`. Always returns 0; RPC status is reported via the closure.
int BrpcPsClient::PushSparseAsyncShardPush(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num, int table_id, int shard_idx,
    DownpourBrpcClosure *closure, ValueAccessor *accessor) {
  PushSparseAsyncShardMerge(task_list, request_kv_num, table_id, shard_idx,
                            accessor);
  size_t merged_kv_count =
      task_list[0]->data()->shared_data[shard_idx].kv_num;

  auto &merged_key_list =
      task_list[0]->data()->shared_data[shard_idx].key_list;
  auto &merged_value_list =
      task_list[0]->data()->shared_data[shard_idx].value_list;

  // Build the RPC request; the merged kv count travels as a param.
  // NOTE(review): merged_kv_count is a size_t but only sizeof(uint32_t)
  // bytes are sent, which assumes a little-endian host (consistent with the
  // rest of this wire protocol).
  auto *push_request = closure->request(shard_idx);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
                           sizeof(uint32_t));  // NOLINT
  // Payload layout: [merged_kv_count uint64 keys][merged values].
  auto *push_data = push_request->mutable_data();
  int update_size = accessor->GetAccessorInfo().update_size;
  push_data->resize(merged_kv_count *
                    (sizeof(uint64_t) + update_size));
  char *push_data_ptr = const_cast<char *>(push_data->data());
  memcpy(push_data_ptr, merged_key_list.data(),
         merged_kv_count * sizeof(uint64_t));
  push_data_ptr += merged_kv_count * sizeof(uint64_t);

  for (size_t i = 0; i < merged_kv_count; ++i) {
    const char *task_data_ptr = merged_value_list[i].data();

    memcpy(push_data_ptr, (float *)(task_data_ptr),  // NOLINT
           update_size);
    push_data_ptr += update_size;
  }

  PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
  closure->cntl(shard_idx)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
                   closure->response(shard_idx), closure);
  _push_sparse_merge_count_map[table_id] = 0;
  return 0;
}
// Enqueue a dense-gradient push for `table_id`.
// Copies the caller's regions into a single contiguous per-task buffer sized
// for all server shards, then hands the task to the background consumer
// thread (PushDenseTaskConsume) via the per-table queue. Applies simple
// back-pressure: blocks while the queue exceeds
// FLAGS_pserver_max_async_call_num. Returns a future completed when the RPC
// round-trip for the (possibly merged) task finishes.
std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  int fea_dim = accessor->GetAccessorInfo().fea_dim;
  int update_dim = accessor->GetAccessorInfo().update_dim;
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
  auto parse_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_parse");
  int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  // Back-pressure loop: wait for the consumer thread to drain the queue.
  while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushDense Waiting for async_call_num comsume,
    // task_num:"
    //           << push_dense_async_num
    //           << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  }
  auto push_dense_timer = std::make_shared<CostTimer>("push_dense_put");
  // auto dense_data = _dense_matrix_obj_pool.get();
  auto dense_data = std::make_shared<std::vector<float>>();
  // Ownership of async_task transfers to the queue; the consumer thread wraps
  // it in a shared_ptr (or deletes it after merging).
  auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Copy the region data into the transposed send buffer: one
  // num_per_shard-sized slice per server shard.
  async_task->data()->resize(num_per_shard * request_call_num * update_dim);
  float *data = async_task->data()->data();
  size_t data_size = async_task->data()->size();
  uint32_t pos = 0;
  for (size_t i = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    CHECK(pos + data_num <= data_size)
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << data_num << "] size[" << data_size << "]";
    const float *region_data = (const float *)(regions[i].data);
    memcpy(data + pos, region_data, regions[i].size);
    pos += data_num;
  }
  // Grab the future before Put(): the consumer may complete the task
  // immediately on another thread.
  std::future<int> fut = async_task->get_future();
  _push_dense_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
// Background consumer loop for dense-gradient pushes (runs on
// _async_push_dense_thread while _running is true).
// For each table queue: once more than FLAGS_pserver_push_dense_merge_limit
// tasks are pending (or a Flush is in progress), pops one task as the merge
// destination, sums up to merge_size further tasks into it via a local
// thread pool, optionally rescales the merged gradient, and sends the result
// with PushDenseRawGradient. Sleeps so each sweep takes at least
// FLAGS_pserver_async_push_dense_interval_ms.
void BrpcPsClient::PushDenseTaskConsume() {
  uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit;
  static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge;
  ::ThreadPool async_merge_dense_threads(10);
  while (_running) {
    auto async_start_time_ms = butil::gettimeofday_ms();
    for (auto &task_queue_itr : _push_dense_task_queue_map) {
      auto &task_queue = task_queue_itr.second;
      auto queue_size = task_queue->Size();
      if (queue_size == 0) {
        continue;
      }
      // Wait for a full batch unless a flush forces the send.
      if (queue_size <= merge_size && _flushing == false) {
        continue;
      }
      ++_async_call_num;
      DenseAsyncTask *task;
      task_queue->Get(task);
      auto *accessor = GetTableAccessor(task->table_id());
      // Set up the per-shard RPC callback: any failed sub-response fails the
      // whole batch; the closure also decrements the in-flight counter.
      size_t request_call_num = _server_channels.size();
      DownpourBrpcClosure *closure = new DownpourBrpcClosure(
          request_call_num, [this, request_call_num](void *done) {
            int ret = 0;
            auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
            for (size_t i = 0; i < request_call_num; ++i) {
              if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
                ret = -1;
                break;
              }
            }
            closure->set_promise_value(ret);
            --_async_call_num;
          });
      auto &total_send_data_vec = *(task->data());
      float *total_send_data =
          reinterpret_cast<float *>(total_send_data_vec.data());
      size_t total_send_data_size = total_send_data_vec.size();
      {
        CostTimer merge_timer("pserver_client_push_dense_merge");
        uint32_t merge_count = 0;
        std::vector<std::future<int>> merge_status(merge_size);
        // Fold up to merge_size queued tasks into total_send_data.
        while (!task_queue->Empty() && merge_count < merge_size) {
          // NOTE(review): Get() overwrites this freshly allocated pointer,
          // so the placeholder object appears to leak — verify
          // Channel::Get semantics.
          auto *async_task = new DenseAsyncTask();
          task_queue->Get(async_task);
          closure->add_timer(async_task->timer());
          closure->add_promise(async_task->promise());
          merge_status[merge_count] = async_merge_dense_threads.enqueue(
              [closure, accessor, &total_send_data, total_send_data_size,
               async_task]() -> int {
                auto &tmp_task_vec = *(async_task->data());
                const float *merge_data = tmp_task_vec.data();
                // Accumulate this task's gradient into the shared buffer.
                // NOTE(review): concurrent Merge calls all target
                // total_send_data — presumably Merge is safe to run in
                // parallel here; confirm against accessor implementation.
                accessor->Merge(&total_send_data, &merge_data,
                                total_send_data_size);
#pragma optimize("", off)
                delete async_task;
#pragma optimize("", on)
                return 0;
              });
          ++merge_count;
        }
        for (size_t i = 0; i < merge_count; ++i) {
          merge_status[i].wait();
        }
        // NOTE(review): the "[-2]" label below is followed by an extra
        // total_send_data[0] print — looks like a copy-paste slip in the log
        // line only.
        VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge "
                   "total_send_data[0]"
                << total_send_data[0] << " total_send_data[-2]"
                << total_send_data[total_send_data_size - 2]
                << total_send_data[0] << " total_send_data[-1]"
                << total_send_data[total_send_data_size - 1];
        if (scale_gradient && merge_count > 1) {
          // Average the merged gradient (merge_count tasks + the base task).
          Eigen::Map<Eigen::MatrixXf> mat(total_send_data, 1,
                                          total_send_data_size);
          mat *= (1.0 / (merge_count + 1));
        }
        VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge "
                   "total_send_data[0]"
                << total_send_data[0] << " total_send_data[-2]"
                << total_send_data[total_send_data_size - 2]
                << " total_send_data[-1]"
                << total_send_data[total_send_data_size - 1]
                << " merge_count " << merge_count;
      }
      // task_ptr takes ownership of the base task for the send path.
      std::shared_ptr<DenseAsyncTask> task_ptr(task);
      PushDenseRawGradient(task_ptr, total_send_data, total_send_data_size,
                           closure);
    }
    // Pace the loop to one sweep per configured interval.
    auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms -
                   (butil::gettimeofday_ms() - async_start_time_ms);
    if (wait_ms > 0) {
      usleep(wait_ms * 1000);
    }
  }
}
// Send one merged dense-gradient buffer to every server shard.
// Splits total_send_data into request_call_num slices of num_per_shard
// floats and issues one PS_PUSH_DENSE_TABLE RPC per shard; `closure`
// completes when all sub-requests have responded.
void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
                                        float *total_send_data,
                                        size_t total_send_data_size,
                                        DownpourBrpcClosure *closure) {
  auto *accessor = GetTableAccessor(task->table_id());
  size_t request_call_num = _server_channels.size();
  // Copy the data into each request's buffer.
  auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
  closure->add_timer(timer);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
  auto send_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_send");
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
    closure->request(i)->set_table_id(task->table_id());
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    // Wire layout per shard: [num_per_shard: uint32_t][values: floats].
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data + i * num_per_shard,
           num_per_shard * sizeof(float));
    closure->cntl(i)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_ps_client.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace
brpc
{
class
Channel
;
class
Controller
;
}
// namespace brpc
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
namespace
paddle
{
namespace
distributed
{
struct
Region
;
// brpc service hosted on the *client* side so that other clients (and the
// server) can push messages to this worker; dispatching is delegated to the
// owning PSClient.
class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}
  // Bind this service to its owning client and record the worker rank.
  // Always returns 0. Stores a non-owning pointer: `client` must outlive
  // this service.
  virtual int32_t Configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
  // brpc entry point for incoming client2client requests (defined in the
  // .cc file).
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request,
               PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;

 protected:
  size_t _rank;      // this worker's rank id
  PSClient *_client; // non-owning back-pointer to the owning client
};
// Closure that fans one logical request out to `num` brpc sub-requests and
// fires the user callback exactly once, after the last sub-response arrives.
// Heap-allocate only: Run() deletes `this` when the countdown reaches zero.
class DownpourBrpcClosure : public PSClientClosure {
 public:
  // Pre-allocates one request/response/controller triple per sub-request.
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
  // Called by brpc once per finished sub-request; the thread that observes
  // the counter hit zero runs the callback and self-destructs.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for the i-th sub-request slot (no bounds checking).
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  // Response validation helpers (defined in the .cc file).
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);

 private:
  std::atomic<int32_t> _waiting_num;  // sub-responses still outstanding
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Per-shard staging buffer for merged sparse push data: parallel key/value
// arrays plus the count of valid entries.
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  // Number of valid entries in key_list/value_list. Zero-initialized: the
  // original left it indeterminate, so reading it before the first merge
  // (e.g. an empty shard) was undefined behavior.
  size_t kv_num = 0;
  std::vector<uint64_t> key_list;          // feature keys, one per entry
  std::vector<std::string> value_list;     // serialized value blob per key
};
// Payload of one asynchronous sparse push task.
struct SparsePushTaskData {
  // Sparse data partitioned by key hash; indexed by server shard id
  // (see shared_data[shard_idx] usage in the client .cc).
  std::vector<SharedSparsePushData> shared_data;
};
// push sparse 对象池
struct
SparseTaskPool
{
std
::
shared_ptr
<
SparsePushTaskData
>
get
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
_mutex
);
if
(
_pool
.
empty
())
{
return
std
::
make_shared
<
SparsePushTaskData
>
();
}
else
{
auto
ret
=
_pool
.
back
();
_pool
.
pop_back
();
return
ret
;
}
}
void
push
(
std
::
shared_ptr
<
SparsePushTaskData
>
data
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
_mutex
);
_pool
.
push_back
(
std
::
move
(
data
));
}
std
::
vector
<
std
::
shared_ptr
<
SparsePushTaskData
>>
_pool
;
std
::
mutex
_mutex
;
};
// Deleter functor for buffers allocated with new[]: invokes delete[] so the
// whole array is released (plain delete here would be undefined behavior).
// Kept for use with smart pointers that manage raw arrays; for new code,
// std::unique_ptr<T[]> provides this out of the box.
template <class T>
struct array_deleter {
  void operator()(T *&ptr) const { delete[] ptr; }  // NOLINT
};
// brpc-based parameter-server client. Owns one channel triple
// (sparse/dense/cmd) per server, background threads that merge and send
// asynchronous dense/sparse pushes, and an embedded brpc server for
// client-to-client messaging.
class BrpcPsClient : public PSClient {
 public:
  BrpcPsClient() {}
  // Drains pending async work (Flush), joins the consumer threads and stops
  // the embedded client-service server.
  virtual ~BrpcPsClient() {
    if (_running) {
      Flush();
      _running = false;
    }
    if (_async_push_dense_thread.joinable()) {
      _async_push_dense_thread.join();
    }
    if (_async_push_sparse_thread.joinable()) {
      _async_push_sparse_thread.join();
    }
    if (_server_started) {
      _server.Stop(1000);
      _server.Join();
      _server_started = false;
    }
  }
  // Establish brpc channels to all other clients for client2client traffic.
  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry);
  // Table lifecycle commands (forwarded to all servers; futures resolve
  // when every server has answered).
  std::future<int32_t> Shrink(uint32_t table_id,
                              const std::string threshold) override;
  std::future<int32_t> Load(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Load(uint32_t table_id,
                            const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Save(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Save(uint32_t table_id,
                            const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Clear() override;
  std::future<int32_t> Clear(uint32_t table_id) override;
  std::future<int32_t> StopServer() override;
  std::future<int32_t> StartProfiler() override;
  std::future<int32_t> StopProfiler() override;
  void FinalizeWorker() override;
  // Dense parameter pull/push (regions are caller-owned raw buffers).
  virtual std::future<int32_t> PullDense(Region *regions,
                                         size_t region_num,
                                         size_t table_id);
  virtual std::future<int32_t> PushDenseParam(const Region *regions,
                                              size_t region_num,
                                              size_t table_id);
  virtual std::future<int32_t> PushDense(const Region *regions,
                                         size_t region_num,
                                         size_t table_id);
  // Background loop that merges and sends queued dense push tasks.
  void PushDenseTaskConsume();
  // Sparse pulls; select_values receives per-key value pointers.
  virtual std::future<int32_t> PullSparse(float **select_values,
                                          size_t table_id,
                                          const uint64_t *keys,
                                          size_t num,
                                          bool is_training);
  virtual std::future<int32_t> PullSparseParam(float **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num,
                                               bool is_training);
  virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type);
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float> *values,
                                            std::vector<uint64_t> *keys,
                                            int pserver_idx);
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t *total_send_data,
                                              void *done);
  // Block until all queued async pushes have been sent and acknowledged.
  virtual std::future<int32_t> Flush();
  std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                            int to_client_id,
                                            const std::string &msg) override;
  // for local save sparse
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string &path);
  // Cache management commands.
  std::future<int32_t> CacheShuffle(uint32_t table_id,
                                    const std::string &path,
                                    const std::string &mode,
                                    const std::string &cache_threshold) override;
  std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold);
  std::future<int32_t> SaveCache(uint32_t table_id,
                                 const std::string &path,
                                 const std::string &mode) override;
  std::future<int32_t> GetCacheThreshold(uint32_t table_id,
                                         double &cache_threshold) override;
  // Debug helpers: dump async queue sizes (once / periodically).
  void PrintQueueSize();
  void PrintQueueSizeThread();

 protected:
  virtual size_t GetServerNums() { return _server_channels.size(); }
  // Channel accessors: each server exposes three channels, indexed
  // 0 = sparse traffic, 1 = dense traffic, 2 = control commands.
  inline brpc::Channel *GetSparseChannel(size_t server_id) {
    return _server_channels[server_id][0].get();
  }
  inline brpc::Channel *GetDenseChannel(size_t server_id) {
    return _server_channels[server_id][1].get();
  }
  inline brpc::Channel *GetCmdChannel(size_t server_id) {
    return _server_channels[server_id][2].get();
  }
  int32_t Initialize() override;

 private:
  // Dense dims hosted per shard. NOTE(review): always rounds up by one even
  // when dense_dim_total divides shard_num evenly — callers appear to rely
  // on this padding; confirm before changing.
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }
  // Broadcast a control command to every server.
  std::future<int32_t> SendCmd(uint32_t table_id,
                               int cmd_id,
                               const std::vector<std::string> &param);
  std::future<int32_t> SendSaveCmd(uint32_t table_id,
                                   int cmd_id,
                                   const std::vector<std::string> &param);
  bool _running = false;   // consumer threads keep looping while true
  bool _flushing = false;  // set during Flush() to force partial batches out
  std::atomic<uint32_t> _async_call_num;  // in-flight async request count
  // Async push-dense pipeline.
  std::thread _async_push_dense_thread;
  typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
      _push_dense_task_queue_map;
  // Async push-sparse pipeline.
  std::thread _async_push_sparse_thread;
  typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
      _push_sparse_task_queue_map;
  std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;
  std::thread _print_thread;  // periodic queue-size logger
  // Merge all pending sparse tasks for one shard into task_list[0].
  int PushSparseAsyncShardMerge(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,  // NOLINT
      std::vector<int> &request_kv_num,                          // NOLINT
      int table_id,
      int shard_idx,
      ValueAccessor *accessor);
  // Merge then send the pending sparse tasks for one shard.
  int PushSparseAsyncShardPush(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,  // NOLINT
      std::vector<int> &request_kv_num,                          // NOLINT
      int table_id,
      int shard_idx,
      DownpourBrpcClosure *closure,
      ValueAccessor *accessor);
  SparseTaskPool _sparse_task_pool;  // recycles sparse task payloads
  std::vector<std::shared_ptr<brpc::Channel>>
      _client_channels;  // client2client
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
      _server_channels;  // client2server
  std::future<int32_t> PushDenseRawGradient(int table_id,
                                            float *total_send_data,
                                            size_t total_send_data_size,
                                            void *done) override;
  std::future<int32_t> PushSparseRawGradient(size_t table_id,
                                             const uint64_t *keys,
                                             const float **update_values,
                                             size_t num,
                                             void *done) override;
  std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
                                                    const uint64_t *keys,
                                                    const float **update_values,
                                                    uint32_t num,
                                                    void *done,
                                                    int pserver_idx) override;
  std::future<int32_t> PushSparseParam(size_t table_id,
                                       const uint64_t *keys,
                                       const float **update_values,
                                       size_t num,
                                       void *done) override;
  std::future<int32_t> PushSparse(size_t table_id,
                                  const uint64_t *keys,
                                  const float **update_values,
                                  size_t num) override;
  // Background loop that merges and sends queued sparse push tasks.
  void PushSparseTaskConsume();

 private:
  // Start the embedded brpc server that receives client2client messages.
  int32_t StartClientService();
  // Internal send path used by PushDenseTaskConsume after merging.
  void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,  // NOLINT
                            float *total_send_data,
                            size_t total_send_data_size,
                            DownpourBrpcClosure *closure);
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
  brpc::Server _server;               // hosts _service for client2client
  DownpourPsClientService _service;
  bool _server_started = false;
  std::atomic_uint grad_num_{0};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "butil/object_pool.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
// gflags controlling the pserver-to-pserver (s2s) brpc channels created in
// BrpcPsServer::StartS2S().
DEFINE_int32(pserver_timeout_ms_s2s, 10000,
             "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000,
             "pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s, "pooled",
              "pserver connection_type[pooled:single]");
namespace
paddle
{
namespace
distributed
{
// Instantiate the configured PsBaseService subclass, initialize it, and
// register it with the embedded brpc server.
// Returns 0 on success, -1 on any configuration/registration failure.
int32_t BrpcPsServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  // Create the service by registered class name (factory lookup).
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service owns the instance from here on.
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  // brpc must not delete the service: ownership stays with _service.
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Start the brpc server on ip:port, register it with the environment, then
// BLOCK until stoped_ is signalled. Returns this server's rank, or 0 if the
// listen attempt (including the numeric-endpoint retry) failed.
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(0) << "running server with rank id: " << _rank
          << ", endpoint: " << ip_port;
  brpc::ServerOptions options;
  // Use at least one brpc worker thread per trainer.
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
            << " , Try Again.";
    // Retry with a purely numeric endpoint (hostname may not resolve).
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    if (_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
      return 0;
    }
  }
  _environment->RegistePsServer(ip, port, _rank);
  // Park this thread until Stop() flips stoped_.
  cv_.wait(lock, [&] { return stoped_; });
  PSHost host;
  host.ip = ip;
  host.port = port;
  host.rank = _rank;
  return host.rank;
}
// Open one brpc channel to every pserver (including self) for
// server-to-server traffic, using the *_s2s gflags for timeouts and
// connection type. Individual connect failures are logged but do not abort;
// always returns 0.
int32_t BrpcPsServer::StartS2S() {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
  options.connection_type = FLAGS_pserver_connection_type_s2s;
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
  options.max_retry = 3;
  std::vector<PSHost> pserver_list = _environment->GetPsServers();
  _pserver_channels.resize(pserver_list.size());
  VLOG(2) << "pserver start s2s server_list size: "
          << _pserver_channels.size();
  std::ostringstream os;  // accumulates endpoints for the summary log line
  std::string server_ip_port;
  for (size_t i = 0; i < pserver_list.size(); ++i) {
    server_ip_port.assign(pserver_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(pserver_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) !=
        0) {
      LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
                 << " Failed!";
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "pserver connect success: " << os.str();
  return 0;
}
// Send an asynchronous server-to-server message (cmd_id 101) to the pserver
// with rank `to_pserver_id`. The returned future resolves with the response
// check result (msg_type + 1000 is the expected response command id).
std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
    int msg_type, int to_pserver_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  if (static_cast<size_t>(to_pserver_id) >= _pserver_channels.size()) {
    // NOTE(review): glog LOG(FATAL) aborts the process, so the
    // set_value(-1)/return below are dead code — presumably this was meant
    // to be LOG(ERROR); confirm intended severity.
    LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
               << _pserver_channels.size();
    promise->set_value(-1);
    return fut;
  }
  // Single-sub-request closure: validates the response then resolves the
  // promise. The closure deletes itself after running.
  auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
    auto *closure = (DownpourPServerBrpcClosure *)done;
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(101);  // s2s message command id
  closure->request(0)->set_client_id(_rank);
  closure->request(0)->set_table_id(0);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
  rpc_stub.service(closure->cntl(0), closure->request(0),
                   closure->response(0), closure);
  return fut;
}
// Handle a server-to-server payload: deserialize `msg` as a sequence of
// (uint64_t, string) pairs and append them to the local shuffle channel.
// Empty payloads are treated as normal end-of-stream markers. Returns 0.
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
                                         int pserver_id,
                                         const std::string &msg) {
  if (msg.length() == 0) {
    LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
    return 0;
  }
  paddle::framework::BinaryArchive ar;
  // The archive only reads from the buffer; const_cast is required by the
  // SetReadBuffer signature.
  ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
  if (ar.Cursor() == ar.Finish()) {
    LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
    return 0;
  }
  std::vector<std::pair<uint64_t, std::string>> data;
  while (ar.Cursor() < ar.Finish()) {
    data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
  }
  // The payload must decode exactly — no trailing bytes allowed.
  CHECK(ar.Cursor() == ar.Finish());
  this->_shuffled_ins->Write(std::move(data));
  return 0;
}
// Return the port the embedded brpc server is actually listening on (useful
// when the requested port was 0 / auto-assigned).
int32_t BrpcPsServer::Port() { return _server.listen_address().port; }
// Populate the cmd_id -> member-function dispatch table used by service(),
// register the cost profilers, and initialize shard info. Returns 0.
int32_t BrpcPsService::Initialize() {
  _is_initialize_shard_info = false;
  // Command dispatch table: cmd ids < 100 are routed through this map.
  _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
  _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
  _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
  _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
  _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
  _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
  _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
  _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
  _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
  _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
  _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
  _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
  _service_handler_map[PS_PUSH_SPARSE_PARAM] =
      &BrpcPsService::PushSparseParam;
  _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
  // for save cache
  _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
      &BrpcPsService::SaveCacheTable;
  _service_handler_map[PS_GET_CACHE_THRESHOLD] =
      &BrpcPsService::GetCacheThreshold;
  _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_server_pull_dense");
  profiler.register_profiler("pserver_server_push_dense");
  profiler.register_profiler("pserver_server_pull_sparse");
  profiler.register_profiler("pserver_server_push_sparse");
  // Shard initialization: the server_list shard info is only available from
  // the environment after the server has started.
  InitializeShardInfo();
  return 0;
}
// Guard macro used at the top of every table-scoped handler: if the table
// lookup failed, record an error on the response and bail out with -1.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }
// Lazily propagate (rank, shard_num) to every table, exactly once.
// Uses a double-checked pattern: an unlocked fast-path read, then a
// re-check under _initialize_shard_mutex before doing the work.
// NOTE(review): the unlocked first read of _is_initialize_shard_info is a
// plain bool, not atomic — benign in practice here but technically racy;
// confirm whether callers can race this method.
// Returns 0 in all cases.
int32_t BrpcPsService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
    size_t shard_num = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // Iterate by reference: the original `auto itr` copied every map
    // element (key + mapped value) on each iteration for no benefit.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, shard_num);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// brpc entry point for all PS requests.
// Validates the request, then dispatches: cmd ids < 100 go through the
// handler map populated in Initialize(); larger ids are treated as
// pserver-to-pserver messages. Errors are reported via the response's
// err_code/err_msg, never by throwing.
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
                            const PsRequestMessage *request,
                            PsResponseMessage *response,
                            google::protobuf::Closure *done) {
  // Guarantees done->Run() fires on every exit path.
  brpc::ClosureGuard done_guard(done);
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    // NOTE(review): "tabel_id" typo is preserved — it is a runtime message
    // that external tooling may match on.
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }
  // Default to success; handlers overwrite on failure.
  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  if (request->cmd_id() < 100) {
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      set_response_code(*response, -1, err_msg.c_str());
      return;
    }
    // Invoke the member-function handler selected by cmd_id.
    serviceHandlerFunc handler_func = itr->second;
    int service_ret = (this->*handler_func)(table, *request, *response, cntl);
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  } else {
    // cmd_id >= 100: inter-pserver message path.
    int service_ret = _server->HandlePServer2PServerMsg(
        request->cmd_id(), request->client_id(), request->data());
    if (service_ret != 0) {
      response->set_err_code(-1);
      response->set_err_msg("handle_pserver2pserver_msg failed");
    }
  }
}
// Handle PS_PULL_DENSE_TABLE: read the requested element count from
// params(0), pull that many dense values from the table, and return them to
// the client in the response attachment as raw floats.
// Returns 0 (protocol errors are reported through set_response_code).
int32_t BrpcPsService::PullDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 1 for num of dense");
    return 0;
  }
  CostTimer timer("pserver_server_pull_dense");
  // params(0) carries the element count as raw little-endian bytes.
  uint32_t num = *(const uint32_t *)request.params(0).c_str();
  // Pull buffer comes from butil's object pool and is returned below.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
                   sizeof(float));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = res_data->data();
  table_context.num = num;
  table->Pull(table_context);
  // table->PullDense(res_data->data(), num);
  // Ship the floats back as a raw byte attachment (append copies the data,
  // so recycling res_data afterwards is safe).
  cntl->response_attachment().append((char *)(res_data->data()),
                                     res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
// Handle PS_PUSH_DENSE_PARAM: copy the dense parameter payload out of the
// request attachment and push it into the table as parameter values
// (is_param = true).
// Attachment layout: [num: uint32_t][values: num floats].
// Returns 0; errors are reported via set_response_code.
int32_t BrpcPsService::PushDenseParam(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->PushDenseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  // thread_local scratch buffer: reused across requests on this worker
  // thread to avoid per-request allocation.
  thread_local std::string push_buffer;
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  // FIX: the buffer must actually own req_buffer_size bytes. The original
  // resize(0) + reserve(n) left size() == 0, so when IOBuf::fetch copies a
  // fragmented attachment into push_buffer.data() it wrote past size(),
  // which is undefined behavior. resize() keeps the reuse benefit
  // (capacity is retained across calls) while making the write legal.
  push_buffer.resize(req_buffer_size);
  // fetch() either returns a pointer into the IOBuf (contiguous case) or
  // copies into the supplied area and returns that.
  const char *data = (const char *)cntl->request_attachment().fetch(
      const_cast<char *>(push_buffer.data()), req_buffer_size);
  uint32_t num = *(const uint32_t *)data;
  const float *values = (const float *)(data + sizeof(uint32_t));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;  // parameters, not gradients
  table_context.num = num;
  // if (table->PushDenseParam(values, num) != 0) {
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDenseParam failed");
  }
  return 0;
}
int32_t BrpcPsService::PushDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  // Applies a dense gradient update carried inline in request.data().
  // Payload layout:
  //   |--num--|---valuesData---|
  //   |--4B---|----------------|
  platform::RecordEvent record_event(
      "PsService->PushDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  const auto &payload = request.data();
  if (payload.size() < 1) {
    // Empty push is treated as a no-op, not an error.
    // set_response_code(response, 0, "push dense data is empty");
    return 0;
  }
  CostTimer timer("pserver_server_push_dense");
  const char *raw = payload.data();
  TableContext ctx;
  ctx.value_type = Dense;
  ctx.num = *reinterpret_cast<const uint32_t *>(raw);
  ctx.push_context.values =
      reinterpret_cast<const float *>(raw + sizeof(uint32_t));
  if (table->Push(ctx) != 0) {
    set_response_code(response, -1, "PushDense failed");
  }
  return 0;
}
int32_t BrpcPsService::Barrier(Table *table,
                               const PsRequestMessage &request,
                               PsResponseMessage &response,
                               brpc::Controller *cntl) {
  // Registers that trainer `client_id` reached the barrier whose type is
  // carried in params(0); the table implements the actual rendezvous.
  CHECK_TABLE_EXIST(table, request, response)

  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }

  const auto client_rank = request.client_id();
  const auto barrier_kind = request.params(0);
  table->Barrier(client_rank, barrier_kind);
  return 0;
}
// Initializes sparse-table parameters for the given keys.
// request.data() layout:
//   |---keysData---|---valuesData---|
//   |---8*{num}B---|----------------|
// params(0) holds the raw little-endian uint32 key count `num`
// (read via a pointer cast, not decimal text).
int32_t BrpcPsService::PushSparseParam(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushSparseParam", platform::TracerEventType::Communication,
      1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // Empty push is a silent no-op by design.
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  // `num` is the count of sparse keys, encoded as raw bytes in params(0).
  uint32_t num = *(uint32_t *)(request.params(0).c_str());
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  const uint64_t *keys = (const uint64_t *)push_data.data();
  // Values start immediately after the 8-byte keys block.
  const float *values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;  // parameter init, not gradient
  table_context.num = num;
  // if (table->PushSparseParam(keys, values, num) != 0) {
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparseParam error");
  }
  return 0;
}
// Pulls geo-sgd deltas accumulated for a specific trainer.
// Response attachment layout:
//   |--num (4B uint32)--|--ids (num*8B uint64)--|--values (floats)--|
// The table fills `ids`/`values` via the geo_pull_* fields of TableContext.
int32_t BrpcPsService::PullGeoParam(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->pull_geo_param", platform::TracerEventType::Communication,
      1);
  CHECK_TABLE_EXIST(table, request, response)
  // NOTE(review): this buffer is declared but never used in this handler —
  // presumably a leftover; confirm before removing.
  thread_local std::string push_sparse_request_buffer;
  auto trainer_id = request.client_id();
  std::vector<float> values;
  std::vector<uint64_t> ids;
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.geo_pull_keys = &ids;
  table_context.pull_context.geo_pull_values = &values;
  table_context.trainer_id = trainer_id;
  table->Pull(table_context);
  // table->PullGeoParam(trainer_id, &values, &ids);
  uint32_t num = ids.size();
  // Serialize count, then keys, then values, back-to-back.
  cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
  cntl->response_attachment().append((char *)ids.data(),
                                     ids.size() * sizeof(uint64_t));
  cntl->response_attachment().append((char *)values.data(),
                                     values.size() * sizeof(float));
  return 0;
}
int32_t BrpcPsService::PullSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Pulls `num` sparse rows of width `select_dim`; the serialized key set
  // arrives in the request attachment, the float values are streamed back
  // in the response attachment.
  platform::RecordEvent record_event(
      "PsService->PullSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_pull_sparse");
  // params(0) carries the key count as raw uint32 bytes.
  uint32_t num = *(uint32_t *)(request.params(0).c_str());
  auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
  thread_local std::string req_buffer;
  // BUGFIX: the buffer must be sized (not merely reserved) before
  // IOBuf::fetch may copy req_buffer_size bytes into it; the original
  // reserve() left size()==0, so a copying fetch wrote past the string's
  // size through data() — undefined behavior.
  req_buffer.resize(req_buffer_size);
  const void *data = cntl->request_attachment().fetch(
      const_cast<char *>(req_buffer.data()), req_buffer_size);
  auto value = PullSparseValue(num, dim);
  value.DeserializeFromBytes(const_cast<void *>(data));
  // Pooled response buffer; returned to butil's object pool below.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * dim);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.pull_value = value;
  table_context.pull_context.values = res_data->data();
  table->Pull(table_context);
  // table->PullSparse(res_data->data(), value);
  cntl->response_attachment().append((char *)(res_data->data()),
                                     res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
// Applies a sparse gradient update.
// request.data() layout:
//   |---keysData---|---valuesData---|
//   |---8*{num}B---|----------------|
// params(0) carries the key count `num` as raw uint32 bytes.
int32_t BrpcPsService::PushSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // Empty push is a silent no-op by design.
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_push_sparse");
  uint32_t num = *(uint32_t *)(request.params(0).c_str());
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = (const uint64_t *)push_data.data();
  // Gradient values start right after the num*8-byte keys block.
  table_context.push_context.values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  table_context.num = num;
  // const uint64_t *keys = (const uint64_t *)push_data.data();
  // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
  // num);
  if (table->Push(table_context) != 0) {
    // if (table->PushSparse(keys, values, num) != 0) {
    set_response_code(response, -1, "PushSparse error");
  }
  return 0;
}
int32_t BrpcPsService::PrintTableStat(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Serializes the table's (first, second) stat pair with BinaryArchive and
  // ships the raw bytes back in the response payload.
  CHECK_TABLE_EXIST(table, request, response)
  const std::pair<int64_t, int64_t> stat = table->PrintTableStat();
  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;
  std::string serialized(archive.Buffer(), archive.Length());
  response.set_data(serialized);
  return 0;
}
int32_t BrpcPsService::LoadOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Loads table state; params(0) = path, params(1) = load_param mode.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  const int load_rc = table->Load(request.params(0), request.params(1));
  if (load_rc != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
int32_t BrpcPsService::LoadAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Fans LoadOneTable out over every registered table, aborting on the
  // first failure. The `table` argument is unused; tables come from _server.
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    if (LoadOneTable(entry.second.get(), request, response, cntl) != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
int32_t BrpcPsService::SaveOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Persists one table; params(0) = path, params(1) = mode.
  // Returns the saved feasign count (>= 0) on success, -1 on failure.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  // Make pending updates durable before writing the snapshot.
  table->Flush();
  VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
  const int32_t feasign_size =
      table->Save(request.params(0), request.params(1));
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
int32_t BrpcPsService::SaveAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Saves every registered table, aborting on the first failure.
  // Per-table feasign counts are not accumulated; success returns 0.
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    const int32_t saved =
        SaveOneTable(entry.second.get(), request, response, cntl);
    if (saved < 0) {
      LOG(ERROR) << "save table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
int32_t BrpcPsService::SaveCacheTable(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Persists the table's cache partition using the server-wide shuffled-ins
  // channel; params(0) = path, params(1) = mode.
  // Returns the saved feasign count (>= 0) on success, -1 on failure.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    // BUGFIX: the message previously claimed "at least 3" while the check
    // and the code below consume exactly two params (path & mode).
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  table->Flush();
  int32_t feasign_size = 0;
  // if (_server->_shuffled_ins->size() <= 0) {
  //   LOG(WARNING) << "shuffled ins size <= 0";
  //}
  feasign_size = table->SaveCache(request.params(0), request.params(1),
                                  _server->_shuffled_ins);
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
// Triggers a cross-pserver cache shuffle.
// params: (0)=path, (1)=mode, (2)=cache_threshold, (3..)=optional extra
// table ids to shuffle jointly. Pserver-to-pserver traffic goes through
// SendPServer2PServerMsg; the shuffled rows land in _server->_shuffled_ins.
int32_t BrpcPsService::CacheShuffle(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // start cache shuffle
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(response, -1,
                      "PsRequestMessage.datas is requeired at least 3, "
                      "path&mode&cache_threshold");
    return -1;
  }
  table->Flush();
  double cache_threshold = std::stod(request.params(2));
  LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
  // auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t,
  // std::string>>();
  // shuffled_ins->set_block_size(80000);
  // Bring up server-to-server messaging before any peer sends happen.
  _server->StartS2S();
  // Callback handed to the table so it can route shuffled rows to peers.
  std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
                                     const std::string &msg)>
      send_msg_func = [this](int msg_type, int to_pserver_id,
                             const std::string &msg) -> std::future<int32_t> {
    return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
  };
  // Collect extra tables named by params(3..); fall back to just `table`.
  std::vector<Table *> table_ptrs;
  for (int i = 3; i < request.params_size(); ++i) {
    int table_id = std::stoi(request.params(i));
    Table *table_ptr = _server->GetTable(table_id);
    table_ptrs.push_back(table_ptr);
  }
  if (table_ptrs.empty()) {
    table_ptrs.push_back(table);
  }
  table->CacheShuffle(request.params(0), request.params(1), cache_threshold,
                      send_msg_func, _server->_shuffled_ins, table_ptrs);
  return 0;
}
int32_t BrpcPsService::GetCacheThreshold(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  // Reports the table's cache threshold as a decimal string with 15
  // significant digits in the response payload.
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  double threshold = 0.0;
  threshold = table->GetCacheThreshold();
  if (threshold < 0) {
    LOG(WARNING) << "wrong threshold: " << threshold;
  }
  std::stringstream formatter;
  formatter << std::setprecision(15) << threshold;
  const std::string threshold_text = formatter.str();
  response.set_data(threshold_text);
  return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table,
                                   const PsRequestMessage &request,
                                   PsResponseMessage &response,
                                   brpc::Controller *cntl) {
  // Evicts stale entries from the table; params(0) carries the threshold.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 1, threshold");
    return -1;
  }
  // Flush first so the shrink decision sees all pending updates.
  table->Flush();
  const int shrink_rc = table->Shrink(request.params(0));
  if (shrink_rc != 0) {
    set_response_code(response, -1, "table shrink failed");
    return -1;
  }
  VLOG(3) << "Pserver Shrink Finished";
  return 0;
}
int32_t BrpcPsService::ClearOneTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Flushes pending updates, then drops every entry in this table.
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  table->Clear();
  return 0;
}
int32_t BrpcPsService::ClearAllTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Applies ClearOneTable to every table registered on this server,
  // stopping at the first failure.
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    const int rc = ClearOneTable(entry.second.get(), request, response, cntl);
    if (rc != 0) {
      return -1;
    }
  }
  return 0;
}
int32_t BrpcPsService::StopServer(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Stop asynchronously on a detached thread so this RPC can send its
  // reply before the brpc server itself is torn down.
  auto *server_handle = _server;
  std::thread stopper([server_handle]() {
    server_handle->Stop();
    VLOG(3) << "Server Stoped";
  });
  stopper.detach();
  return 0;
}
int32_t BrpcPsService::StopProfiler(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Stops the platform profiler and dumps results to a per-rank file.
  const auto dump_path = string::Sprintf("server_%s_profile", _rank);
  platform::DisableProfiler(platform::EventSortingKey::kDefault, dump_path);
  return 0;
}
int32_t BrpcPsService::StartProfiler(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Enables CPU-side profiling; paired with StopProfiler which dumps results.
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
// Pushes a trainer's global step value(s) into the table.
// request.data() layout: the int64 step values start after a 4-byte
// header. NOTE(review): the leading uint32 is skipped without being read
// here — presumably a count written by the client; confirm against the
// client-side encoding before relying on it.
int32_t BrpcPsService::PushGlobalStep(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response);
  auto req_buffer_size = request.data().size();
  if (req_buffer_size < 1) {
    // Code 0: empty payload is reported but not treated as a hard error.
    set_response_code(response, 0, "run_program data is empty");
    return 0;
  }
  const int64_t *values =
      (const int64_t *)(request.data().data() + sizeof(uint32_t));
  auto trainer_id = request.client_id();
  TableContext context;
  context.trainer_id = trainer_id;
  context.push_context.push_steps = values;
  // if (table->PushDense(values, trainer_id) != 0) {
  if (table->Push(context) != 0) {
    set_response_code(response, -1, "run_program failed");
  }
  return 0;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_ps_server.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/server.h"
namespace
brpc
{
class
Controller
;
}
// namespace brpc
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
namespace
paddle
{
namespace
distributed
{
class
PsRequestMessage
;
class
PsResponseMessage
;
class
Table
;
// brpc-based parameter server: owns the brpc::Server instance, the request
// dispatching service, and channels to peer pservers for server-to-server
// (S2S) messaging.
class BrpcPsServer : public PSServer {
 public:
  BrpcPsServer() {}
  virtual ~BrpcPsServer() {}
  // Binds and starts the brpc server on ip:port (implementation in .cc).
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Marks the server stopped, wakes any waiter on cv_, then shuts the
  // brpc server down and joins it.
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    stoped_ = true;
    cv_.notify_all();
    _server.Stop(1000);  // grace period (ms) for in-flight RPCs
    _server.Join();
    return 0;
  }
  int32_t Port();
  // Server-to-server messaging (used e.g. by cache shuffle).
  virtual int32_t StartS2S() override;
  virtual ::std::future<int32_t> SendPServer2PServerMsg(
      int msg_type, int to_pserver_id, const std::string &msg) override;
  virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
                                     const std::string &msg) override;

 private:
  virtual int32_t Initialize();
  mutable std::mutex mutex_;    // guards stoped_ together with cv_
  std::condition_variable cv_;  // signalled when Stop() is requested
  bool stoped_ = false;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer pserver for S2S messages.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class
BrpcPsService
;
typedef
int32_t
(
BrpcPsService
::*
serviceHandlerFunc
)(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
// RPC service for the brpc parameter server. `service()` decodes the
// command id from each PsRequestMessage and dispatches to one of the
// private per-command handlers below via _service_handler_map.
class BrpcPsService : public PsBaseService {
 public:
  virtual int32_t Initialize() override;
  // brpc entry point: dispatches `request` to the matching handler and
  // invokes `done` when the response is ready.
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;

 private:
  int32_t InitializeShardInfo();
  // --- dense/sparse parameter traffic ---
  int32_t PullDense(Table *table, const PsRequestMessage &request,
                    PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PushDense(Table *table, const PsRequestMessage &request,
                    PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PushDenseParam(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PushSparseParam(Table *table, const PsRequestMessage &request,
                          PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PullSparse(Table *table, const PsRequestMessage &request,
                     PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PullGeoParam(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t Barrier(Table *table, const PsRequestMessage &request,
                  PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PushSparse(Table *table, const PsRequestMessage &request,
                     PsResponseMessage &response, brpc::Controller *cntl);
  // --- checkpointing / table lifecycle ---
  int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t SaveOneTable(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t SaveAllTable(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t ShrinkTable(Table *table, const PsRequestMessage &request,
                      PsResponseMessage &response, brpc::Controller *cntl);
  int32_t ClearOneTable(Table *table, const PsRequestMessage &request,
                        PsResponseMessage &response, brpc::Controller *cntl);
  int32_t ClearAllTable(Table *table, const PsRequestMessage &request,
                        PsResponseMessage &response, brpc::Controller *cntl);
  // --- server control / diagnostics ---
  int32_t StopServer(Table *table, const PsRequestMessage &request,
                     PsResponseMessage &response, brpc::Controller *cntl);
  int32_t StartProfiler(Table *table, const PsRequestMessage &request,
                        PsResponseMessage &response, brpc::Controller *cntl);
  int32_t StopProfiler(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  // --- cache shuffle / cache persistence ---
  int32_t CacheShuffle(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t SaveCacheTable(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);

  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;  // guards one-time shard-info setup
  // cmd_id -> handler member-function pointer dispatch tables.
  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
};
// Closure tracking a fan-out of `num` parallel brpc requests. Run() is
// invoked once per finished sub-request; the last one to finish fires the
// user callback and deletes the closure (so it must be heap-allocated).
class DownpourPServerBrpcClosure : public PServerClosure {
 public:
  DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
      : PServerClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourPServerBrpcClosure() {}
  // Called by brpc when one sub-request completes; fetch_sub(1)==1 means
  // this was the last outstanding one.
  virtual void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;  // self-owned: freed after the final callback
    }
  }
  // Per-sub-request accessors (i indexes the fan-out slot).
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  // Always report success; kept for interface compatibility with callers
  // that expect a per-response validity check.
  int check_response(size_t request_idx, int cmd_id) { return 1; }
  int check_save_response(size_t request_idx, int cmd_id) { return 1; }

 private:
  std::atomic<int32_t> _waiting_num;  // outstanding sub-requests
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_utils.cc
0 → 100644
View file @
d2d32668
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include <arpa/inet.h>
#include <netdb.h>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
distributed
{
// Maps a wire-format VariableMessage::Type to the framework's proto
// VarType; throws InvalidArgument for any type not listed below.
framework::proto::VarType::Type VarMessageToVarType(
    VariableMessage::Type type) {
  switch (type) {
    case VariableMessage::FP32:
      return framework::proto::VarType::FP32;  // NOLINT
    case VariableMessage::FP64:
      return framework::proto::VarType::FP64;  // NOLINT
    case VariableMessage::INT32:
      return framework::proto::VarType::INT32;  // NOLINT
    case VariableMessage::INT64:
      return framework::proto::VarType::INT64;  // NOLINT
    case VariableMessage::BOOL:
      return framework::proto::VarType::BOOL;  // NOLINT
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "VarMessageToVarType:Unsupported type %d", type));
  }
}
// Packs a set of variables from `scope` into a MultiVarMsg (metadata) plus
// a single IOBuf (raw tensor bytes, concatenated in send order). Variables
// that are neither LoDTensor nor SelectedRows contribute no bytes.
void SerializeToMultiVarMsgAndIOBuf(
    const std::string &message_name,
    const std::vector<std::string> &send_var_name_val,
    const std::vector<std::string> &recv_var_name_val,
    const platform::DeviceContext &ctx, const framework::Scope *scope,
    MultiVarMsg *request, butil::IOBuf *iobuf) {
  // 1. message_name
  request->set_message_name(message_name);
  // 2. var_names
  for (auto &send_var_name : send_var_name_val) {
    request->add_send_var_names(send_var_name);
  }
  for (auto &recv_var_name : recv_var_name_val) {
    request->add_recv_var_names(recv_var_name);
  }
  // 3. VarMessage — one per sent variable, bytes appended in the same order
  //    so the receiver can walk the IOBuf sequentially.
  for (auto &send_var_name : send_var_name_val) {
    auto *send_var_msg = request->add_var_messages();
    butil::IOBuf temp_iobuf;
    send_var_msg->set_varname(send_var_name);
    framework::Variable *var = scope->FindVar(send_var_name);
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    iobuf->append(temp_iobuf);
  }
}
// Serializes a LoDTensor: metadata (type, lod, dtype, dims) goes into
// `var_msg`; the payload appended to `iobuf` is an 8-byte length header
// followed by the raw tensor bytes. GPU tensors are staged through a host
// buffer; non-CPU tensors are silently skipped when CUDA is not compiled in.
void SerializeLodTensor(framework::Variable *var,
                        const platform::DeviceContext &ctx, VarMsg *var_msg,
                        butil::IOBuf *iobuf) {
  auto *tensor = var->GetMutable<framework::LoDTensor>();
  var_msg->set_type(::paddle::distributed::LOD_TENSOR);
  const framework::LoD lod = tensor->lod();
  if (lod.size() > 0) {
    var_msg->set_lod_level(lod.size());
    for (auto &each : lod) {
      VarMsg::LodData *lod_inner = var_msg->add_lod();
      for (auto &d : each) {
        lod_inner->add_lod_data(d);
      }
    }
  }
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto &dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    // 8-byte length prefix, then the tensor bytes themselves.
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device tensor: copy to a temporary host buffer before appending.
    char *temp_ptr = new char[tensor->numel() *
                              framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    // NOTE(review): the copy above is issued on `stream` with no explicit
    // sync before the host buffer is read — presumably memory::Copy
    // synchronizes for D2H; confirm.
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Serializes a SelectedRows variable: row indices are copied into the
// message's `data` field as raw int64s, value-tensor metadata goes into
// `var_msg`, and the value bytes are appended to `iobuf` with an 8-byte
// length prefix (same framing as SerializeLodTensor).
void SerializeSelectedRows(framework::Variable *var,
                           const platform::DeviceContext &ctx, VarMsg *var_msg,
                           butil::IOBuf *iobuf) {
  phi::SelectedRows *slr = var->GetMutable<phi::SelectedRows>();
  auto *tensor = slr->mutable_value();
  auto *rows = slr->mutable_rows();
  var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
  var_msg->set_slr_height(slr->height());
  auto *var_data = var_msg->mutable_data();
  var_data->clear();
  var_data->resize(rows->size() * sizeof(int64_t));
  char *data_ptr = const_cast<char *>(var_data->data());
  // NOTE(review): &((*rows)[0]) dereferences element 0 even when rows is
  // empty — UB for a zero-row SelectedRows; confirm callers never send one.
  memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto &dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device tensor: stage through a temporary host buffer.
    char *temp_ptr = new char[tensor->numel() *
                              framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(), temp_ptr, tensor->place(), tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Mutable-scope overload: creates (scope->Var) each variable named in the
// message and fills it from the IOBuf, walking the buffer sequentially in
// message order. Unknown message types are skipped without consuming bytes.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg &multi_msg,
                                        const butil::IOBuf *iobuf,
                                        const platform::DeviceContext &ctx,
                                        framework::Scope *scope) {
  butil::IOBufBytesIterator io_buffer_itr(*iobuf);
  // size_t shard_buffer_remain = res_io_buffer.size();
  for (int recv_var_index = 0;
       recv_var_index < multi_msg.send_var_names_size(); ++recv_var_index) {
    const auto &msg = multi_msg.var_messages(recv_var_index);
    // Var() creates the variable if it does not yet exist in this scope.
    auto *var = scope->Var(msg.varname());
    if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
    } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
    }
  }
}
// Const-scope overload: every variable named in the message must already
// exist in `scope` (FindVar + enforce), then it is filled from the IOBuf
// in message order, mirroring the mutable-scope overload above.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg &multi_msg,
                                        const butil::IOBuf *iobuf,
                                        const platform::DeviceContext &ctx,
                                        const framework::Scope *scope) {
  butil::IOBufBytesIterator io_buffer_itr(*iobuf);
  // size_t shard_buffer_remain = res_io_buffer.size();
  for (int recv_var_index = 0;
       recv_var_index < multi_msg.send_var_names_size(); ++recv_var_index) {
    const auto &msg = multi_msg.var_messages(recv_var_index);
    auto *var = scope->FindVar(msg.varname());
    PADDLE_ENFORCE_NE(var, nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable %s in scope.", msg.varname()));
    if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
    } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
    }
  }
}
// Rebuilds a LoDTensor from its VarMsg metadata (dims, lod, dtype) and
// consumes its payload from `io_buffer_itr`: an 8-byte length header
// followed by the raw tensor bytes. GPU destinations are staged through a
// temporary host buffer; the iterator advances past the payload either way.
void DeserializeLodTensor(framework::Variable *var, const VarMsg &msg,
                          butil::IOBufBytesIterator &io_buffer_itr,  // NOLINT
                          const platform::DeviceContext &ctx) {
  const auto place = ctx.GetPlace();
  framework::LoDTensor *tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int> vec_dim;
  for (auto &x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  framework::LoD lod;
  for (int i = 0; i < msg.lod_level(); ++i) {
    framework::Vector<size_t> v;
    for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) {
      v.push_back(msg.lod(i).lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);
  // Allocate destination storage with the dtype decoded from the message.
  void *tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    unsigned long data_len;  // NOLINT
    // 8-byte length prefix, then the payload straight into the tensor.
    io_buffer_itr.copy_and_forward((void *)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    unsigned long data_len;  // NOLINT
    char *temp_ptr = new char[tensor->numel() *
                              framework::DataTypeSize(tensor->dtype())];  // NOLINT
    io_buffer_itr.copy_and_forward((void *)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward((void *)temp_ptr, data_len);  // NOLINT
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext &>(ctx).stream();
    memory::Copy(place, tensor_data, platform::CPUPlace(),
                 (void *)temp_ptr,  // NOLINT
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    delete[] temp_ptr;
#endif
  }
}
// Rebuilds a phi::SelectedRows variable: height and row ids come from `msg`
// (row ids travel in msg.data()), while the value tensor's payload is
// consumed from `io_buffer_itr` (8-byte host `unsigned long` length prefix,
// then `data_len` bytes).
void DeserializeSelectedRows(
    framework::Variable* var, const VarMsg& msg,
    butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
    const platform::DeviceContext& ctx) {
  const auto place = ctx.GetPlace();
  auto* slr = var->GetMutable<phi::SelectedRows>();
  framework::Tensor* tensor = slr->mutable_value();
  slr->set_height(msg.slr_height());
  // dims()[0] is the row count; copy the int64 row-id array out of the msg.
  std::vector<int64_t> tmp_rows(msg.dims()[0]);
  memcpy(tmp_rows.data(), msg.data().data(), msg.dims()[0] * sizeof(int64_t));
  slr->set_rows(tmp_rows);
  std::vector<int> vec_dim;
  for (auto& x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  void* tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    unsigned long data_len;                                 // NOLINT
    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    // std::vector replaces the raw new[]/delete[] host staging buffer:
    // exception-safe, released at scope exit.
    std::vector<char> temp_buf(tensor->numel() *
                               framework::DataTypeSize(tensor->dtype()));
    unsigned long data_len;                                 // NOLINT
    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(temp_buf.data(), data_len);
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(place, tensor_data, platform::CPUPlace(), temp_buf.data(),
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
#endif
  }
}
std
::
string
GetIntTypeEndpoint
(
const
std
::
string
&
ip
,
const
uint32_t
&
port
)
{
// There are usually two forms of IP address: ip(int) / ip (hostname)
// If there're some problem with DNS, or ip triggers the bug of Brpc
// We will try to get the IP address of the domain name manually again
std
::
string
ip_port
=
ip
+
":"
+
std
::
to_string
(
port
);
struct
hostent
*
hp
=
NULL
;
hp
=
gethostbyname
(
ip
.
c_str
());
if
(
NULL
==
hp
)
{
LOG
(
ERROR
)
<<
"Brpc Start failed, ip_port= "
<<
ip_port
<<
" , Error infomation: "
<<
hstrerror
(
h_errno
);
}
int
i
=
0
;
char
*
int_ip
=
NULL
;
while
(
hp
->
h_addr_list
[
i
]
!=
NULL
)
{
int_ip
=
inet_ntoa
(
*
(
struct
in_addr
*
)
hp
->
h_addr_list
[
i
]);
VLOG
(
3
)
<<
"Brpc Get host by name, host:"
<<
ip
<<
" -> ip: "
<<
int_ip
;
break
;
}
std
::
string
str_ip
=
int_ip
;
std
::
string
int_ip_port
=
str_ip
+
":"
+
std
::
to_string
(
port
);
return
int_ip_port
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_utils.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Helpers for (de)serializing framework variables (LoDTensor / SelectedRows)
// to and from brpc IOBufs, used by the parameter-server RPC layer.

#pragma once

#include <netdb.h>
#include <iostream>
#include <string>
#include <vector>

#include "brpc/channel.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/phi/backends/dynload/port.h"

// Forward declarations to keep heavy headers out of dependents.
namespace butil {
class IOBuf;
class IOBufBytesIterator;
}  // namespace butil

namespace grpc {
class ByteBuffer;
}  // namespace grpc

namespace paddle {
namespace framework {
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace distributed {

// Short aliases for the protobuf message types generated from sendrecv.proto.
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;

// Serializes the named variables from `scope` into `var_msg` (metadata) and
// `iobuf` (raw tensor payload).
void SerializeToMultiVarMsgAndIOBuf(
    const std::string& message_name,
    const std::vector<std::string>& send_var_name_val,
    const std::vector<std::string>& recv_var_name_val,
    const platform::DeviceContext& ctx, const framework::Scope* scope,
    MultiVarMsg* var_msg, butil::IOBuf* iobuf);

// Serializes a single LoDTensor variable.
void SerializeLodTensor(framework::Variable* var,
                        const platform::DeviceContext& ctx, VarMsg* var_msg,
                        butil::IOBuf* iobuf);

// Serializes a single SelectedRows variable.
void SerializeSelectedRows(framework::Variable* var,
                           const platform::DeviceContext& ctx, VarMsg* request,
                           butil::IOBuf* iobuf);

// Deserialize for Server
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        framework::Scope* scope);

// Deserialize for Client
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        const framework::Scope* scope);

// Rebuilds a LoDTensor from `msg` metadata plus bytes consumed from `iobuf`.
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
                          butil::IOBufBytesIterator& iobuf,  // NOLINT
                          const platform::DeviceContext& ctx);

// Rebuilds a SelectedRows variable from `msg` plus bytes from `iobuf`.
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
                             butil::IOBufBytesIterator& iobuf,  // NOLINT
                             const platform::DeviceContext& ctx);

// Resolves a hostname endpoint to its numeric "ip:port" form.
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port);

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt
0 → 100644
View file @
d2d32668
# Build rules for the PS communicator library.

# RPC_DEPS is populated by the parent service/CMakeLists.txt.
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

# Distributed sources need the extra distribute compile flags.
set_source_files_properties(
  communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(
  communicator
  SRCS communicator.cc
  DEPS scope
       client
       boost
       table
       math_function
       selected_rows_functor
       ${RPC_DEPS})
paddle/fluid/distributed/ps/service/communicator/communicator.cc
0 → 100644
View file @
d2d32668
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include <google/protobuf/text_format.h>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
namespace
paddle
{
namespace
distributed
{
using
framework
::
LoDTensor
;
using
phi
::
SelectedRows
;
const
uint32_t
MAX_FEASIGN_NUM
=
1024
*
100
*
100
;
// Returns the current wall-clock time in microseconds (as a double),
// based on gettimeofday.
inline double GetCurrentUS() {
  struct timeval tv;
  gettimeofday(&tv, NULL);
  return 1e+6 * tv.tv_sec + tv.tv_usec;
}
// Default constructor: intentionally empty — members use their in-class
// initializers; real setup happens later (e.g. InitBrpcClient / InitImpl).
Communicator::Communicator() {}
// Parses a space-separated gflags string and feeds it to
// ParseCommandLineFlags. When no flags are supplied, a set of conservative
// brpc defaults is used instead.
void Communicator::InitGFlag(const std::string& gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.size() < 1) {
    // No user-supplied flags: fall back to default brpc tuning values.
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-bthread_concurrency=40");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  // ParseCommandLineFlags expects argv[0] to be a program name.
  auto it = flags.begin();
  flags.insert(it, "exe default");
  // std::vector replaces the variable-length array used before — VLAs are a
  // non-standard extension in C++; also use const_cast over a C-style cast.
  std::vector<char*> flags_ptr(flags.size());
  for (size_t i = 0; i < flags.size(); ++i) {
    flags_ptr[i] = const_cast<char*>(flags[i].c_str());
  }
  int params_cnt = static_cast<int>(flags.size());
  char** params_ptr = flags_ptr.data();
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
// One-time-init guard for the singleton below (presumably used with
// std::call_once elsewhere — not visible in this chunk; verify at call site).
std::once_flag Communicator::init_flag_;
// Process-wide Communicator singleton; starts empty and is populated lazily.
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
// Lazily borrows the PS worker client owned by the FleetWrapper singleton.
// `dist_desc` and `host_sign_list` are accepted for interface compatibility
// but are not used here.
void Communicator::InitBrpcClient(
    const std::string& dist_desc,
    const std::vector<std::string>& host_sign_list) {
  auto fleet_wrapper = paddle::distributed::FleetWrapper::GetInstance();
  if (_worker_ptr.get() == nullptr) {
    _worker_ptr = fleet_wrapper->worker_ptr_;
  }
}
// Returns the client info list from the PS environment, logging each entry
// at verbose level 2.
std::vector<uint64_t> Communicator::GetClientInfo() {
  std::vector<uint64_t> client_info = _ps_env.GetClientInfo();
  for (const auto& info : client_info) {
    VLOG(2) << "Communicator::GetClientInfo " << info;
  }
  return client_info;
}
// Registers the given client host signatures with the PS environment and
// returns the environment's status code.
int Communicator::SetClients(std::vector<uint64_t>& host_sign_list) {
  const int client_num = static_cast<int>(host_sign_list.size());
  return _ps_env.SetPsClients(host_sign_list.data(), client_num);
}
// Pulls the dense parameters of `varnames` from PS table `table_id` into
// `scope`. GPU-resident tensors are pulled into a CPU-side staging scope
// (xpu_temp_scope_) first and then copied back to the device.
void Communicator::RpcRecvDense(const std::vector<std::string>& varnames,
                                int table_id, Scope* scope) {
  platform::RecordEvent record_event("Communicator->RpcRecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  std::vector<paddle::distributed::Region> regions;
  regions.reserve(varnames.size());
  // Build one Region (pointer + length) per variable for the bulk pull.
  for (auto& t : varnames) {
    Variable* var = scope->Var(t);
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      // Pull into a CPU staging tensor of the same shape.
      Variable* temp_var = xpu_temp_scope_->Var(t);
      LoDTensor* temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float* temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    } else {
      float* w = tensor->mutable_data<float>(tensor->place());
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  // Synchronous pull: block until all regions are filled.
  auto status =
      _worker_ptr->PullDense(regions.data(), regions.size(), table_id);
  status.wait();
  // Copy staged data back onto the device for GPU tensors.
  for (auto& t : varnames) {
    Variable* var = scope->FindVar(t);
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
            << platform::is_gpu_place(tensor->place());
    // NOTE(review): mutable_data with CPUPlace here may reallocate a GPU
    // tensor's buffer on CPU before the copy below — confirm intended.
    float* temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
    VLOG(3) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
            << table_id << " Temp_data[0] " << temp_recv_data[0]
            << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      LoDTensor* temp_tensor =
          xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
      framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
      float* temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    }
  }
  return;
}
// Pushes the current values of the dense variables in `varnames` to PS table
// `table_id` (used to seed the servers with initial parameters; see
// InitParams). GPU tensors are staged through a CPU-side scope first.
void Communicator::RpcSendDenseParam(const std::vector<std::string>& varnames,
                                     int table_id, const Scope& scope) {
  platform::RecordEvent record_event("Communicator->RpcSendDenseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  auto place = platform::CPUPlace();
  std::vector<paddle::distributed::Region> regions;
  for (auto& t : varnames) {
    Variable* var = scope.FindVar(t);
    CHECK(var != nullptr) << "var[" << t << "] not found";
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      // Copy the device tensor into a CPU staging tensor before the RPC.
      Variable* temp_var = xpu_temp_scope_->Var(t);
      LoDTensor* temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float* temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t
              << " table_id " << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    } else {
      float* w = tensor->mutable_data<float>(place);
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
      // NOTE(review): "talbe_id" below is a typo in a runtime log string;
      // intentionally left byte-identical.
      VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t
              << " talbe_id " << table_id << " Temp_data[0] " << w[0]
              << " Temp_data[-1] " << w[tensor->numel() - 1];
    }
  }
  // Synchronous push of all regions to the table.
  auto status =
      _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
  status.wait();
  VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
  return;
}
// Packs the dense gradients of ctx.origin_varnames into one flat buffer and
// pushes it asynchronously as a raw gradient to ctx.table_id. The closure
// decrements _async_call_num when all server responses arrive.
void Communicator::RpcSendDense(const CommContext& ctx, const Scope& scope) {
  platform::RecordEvent record_event("Communicator->RpcSendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto& var_names = ctx.origin_varnames;
  auto& table_id = ctx.table_id;
  // shared_ptr keeps the buffer alive until the async push completes.
  auto dense_data = std::make_shared<std::vector<float>>();
  size_t request_call_num = _worker_ptr->GetServerNums();
  uint32_t num_per_shard =
      DenseDimPerShard(ctx.height_sections[0], request_call_num);
  dense_data->resize(num_per_shard * request_call_num);
  // accessor->update_dim() = 1
  float* data = dense_data->data();
  uint32_t pos = 0;
  for (size_t i = 0; i < var_names.size(); ++i) {
    // NOTE(review): this copies the LoDTensor by value (`const LoDTensor
    // tensor = ...`); a const reference would avoid the copy — verify.
    const LoDTensor tensor = scope.FindVar(var_names[i])->Get<LoDTensor>();
    size_t count = static_cast<size_t>(tensor.numel());
    const float* g = tensor.data<float>();
    CHECK(pos + count <= dense_data->size())
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << count << "] size[" << dense_data->size() << "]";
    memcpy(data + pos, g, count * sizeof(float));
    pos += count;
  }
  ++_async_call_num;
  // Closure aggregates per-server responses; any failure yields ret == -1.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
                                                  dense_data->size(), closure);
  status.wait();
  return;
}
// Pushes a dense embedding table stored in `varname` (rows 0..N-1) as sparse
// parameters to PS table `table_id`: keys are the row indices, values are
// per-row float pointers into the tensor.
void Communicator::RpcSendSparseParam(const std::string& varname, int table_id,
                                      const Scope& scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<float*> push_g_vec;
  auto* send_var = scope.FindVar(varname);
  auto* tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Keys are simply 0..sparse_num-1 (one key per embedding row).
  std::vector<uint64_t> sparse_push_keys(sparse_num);
  std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
  push_g_vec.reserve(sparse_num);
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  // Closure aggregates per-server responses; any failure yields ret == -1.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushSparseParam(
      table_id, sparse_push_keys.data(), (const float**)push_g_vec.data(),
      sparse_push_keys.size(), closure);
  status.wait();
  return;
}
// Pushes sparse gradients held in a SelectedRows variable `var_name` to PS
// table `table_id`: keys are the selected row ids, values are per-row float
// pointers into the value tensor.
void Communicator::RpcSendSparse(const std::string& var_name, int table_id,
                                 const Scope& scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<uint64_t> sparse_push_keys;
  std::vector<float*> push_g_vec;
  auto* send_var = scope.FindVar(var_name);
  auto* tensor = send_var->GetMutable<phi::SelectedRows>();
  auto dim = tensor->value().dims()[1];
  // Convert the int64 row ids into uint64 feature keys.
  std::transform(tensor->rows().begin(), tensor->rows().end(),
                 std::back_inserter(sparse_push_keys),
                 [&](int64_t id) { return static_cast<uint64_t>(id); });

  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
  }

  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
  // if padding_idx == padding in datareader, the server will core.
  /*
  for (size_t i = 0; i < tensor->rows().size(); ++i) {
    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
    if (real_id != 0) {
      sparse_push_keys.push_back(real_id);
      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
    }
  }
  */

  ++_async_call_num;
  // Closure aggregates per-server responses and releases _async_call_num.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushSparseRawGradient(
      table_id, sparse_push_keys.data(), (const float**)push_g_vec.data(),
      sparse_push_keys.size(), closure);
  status.wait();
  return;
}
// Pulls sparse parameters for rows 0..N-1 of the tensor in `varname` from PS
// table `table_id`, writing each row directly into the tensor's storage.
void Communicator::RpcRecvSparse(const std::string& varname, int table_id,
                                 Scope* scope) {
  platform::RecordEvent record_event("Communicator->RpcRecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  auto* send_var = scope->Var(varname);
  auto* tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Keys are the row indices 0..sparse_num-1.
  std::vector<uint64_t> sparse_push_keys(sparse_num);
  std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
  // One destination pointer per row of the output tensor.
  std::vector<float*> push_g_vec;
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  bool training = true;
  auto status = _worker_ptr->PullSparseParam(
      static_cast<float**>(push_g_vec.data()), table_id,
      sparse_push_keys.data(), sparse_push_keys.size(), training);
  status.wait();
  return;
}
// Seeds the parameter servers with the initial dense parameters. Only
// trainer 0 performs the push; every table in the recv map is sent once.
void Communicator::InitParams(const RecvCtxMap& recv_varname_to_ctx) {
  if (trainer_id_ != 0) {
    return;
  }
  for (auto& entry : recv_varname_to_ctx) {
    RpcSendDenseParam(entry.second, entry.first, *recv_scope_);
    VLOG(1) << "push dense param to table " << entry.first
            << " from 0' trainer done";
  }
}
// Pulls the latest dense parameters for every table in the recv map into
// the recv scope.
void Communicator::PullDense(const RecvCtxMap& recv_varname_to_ctx) {
  for (auto& entry : recv_varname_to_ctx) {
    RpcRecvDense(entry.second, entry.first, recv_scope_);
    VLOG(1) << "pull dense param to table " << entry.first
            << " from 0' trainer done";
  }
}
// Keeps server-side profiling in sync with the local profiler state.
// Only trainer 0 issues the start/stop RPCs.
void Communicator::RpcProfilerControl() {
  if (trainer_id_ != 0) {
    return;
  }
  const bool enabled = platform::IsProfileEnabled();
  if (enabled && !do_server_profiler_) {
    // send profiler start flag
    do_server_profiler_ = true;
    auto start_status = _worker_ptr->StartProfiler();
    start_status.wait();
  } else if (!enabled && do_server_profiler_) {
    // send profiler end flag
    auto stop_status = _worker_ptr->StopProfiler();
    stop_status.wait();
    do_server_profiler_ = false;
  }
}
// Pushes the merged batch counter (`batches`) to the global-step tensor
// table ctx.table_id. No-op when batches == 0.
void Communicator::SendGlobalStep(const CommContext& ctx, int batches,
                                  Scope* send_scope) {
  if (batches == 0) {
    return;
  }
  platform::RecordEvent record_event("Communicator->SendGlobalStep",
                                     platform::TracerEventType::Communication,
                                     1);
  auto& table_id = ctx.table_id;
  size_t request_call_num = _worker_ptr->GetServerNums();
  // The step counter travels in a 1-element int64 tensor in send_scope.
  auto& var_name = STEP_COUNTER;
  auto* out_var = send_scope->Var(var_name);
  auto* out_t = out_var->GetMutable<framework::LoDTensor>();
  auto* data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);
  VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
  // Closure aggregates per-server responses; any failure yields ret == -1.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
  status.wait();
  return;
}
void
AsyncCommunicator
::
RecvThread
()
{
if
(
!
independent_recv_
)
return
;
VLOG
(
3
)
<<
"Independent RecvThread Start and Wait"
;
while
(
running_
)
{
int
grad_num
=
grad_num_
.
load
();
if
(
grad_num
>
min_send_grad_num_before_recv_
)
{
RecvByCommunicator
();
grad_num_
.
store
(
0
);
}
else
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
}
}
VLOG
(
1
)
<<
"communicator stopped, independent recv thread exit"
;
}
// Performs one parameter-receive round, unless the communicator has been
// stopped.
void AsyncCommunicator::RecvByCommunicator() {
  if (!running_) {
    return;
  }
  RecvNoBarrier();
  VLOG(3) << "run recv graph end";
}
// Pulls dense parameters for every registered recv table, then copies any
// staged CPU tensors back onto the device for GPU-resident variables.
void AsyncCommunicator::RecvNoBarrier() {
  // Phase 1: pull all tables (GPU vars land in xpu_temp_scope_ via
  // RpcRecvDense).
  for (auto& iter : recv_varname_to_ctx_) {
    auto& table_id = iter.first;
    auto& varnames = iter.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
  }
  // Phase 2: device copy-back for GPU tensors.
  for (auto& iter : recv_varname_to_ctx_) {
    // NOTE(review): `auto var_names = iter.second;` copies the name vector
    // each iteration; a const reference would avoid it — verify.
    auto var_names = iter.second;
    for (auto& t : var_names) {
      Variable* var = recv_scope_->FindVar(t);
      LoDTensor* tensor = var->GetMutable<LoDTensor>();
      VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
              << platform::is_gpu_place(tensor->place());
      if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
        LoDTensor* temp_tensor =
            xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
        framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
#endif
      }
    }
  }
  return;
}
// One send round: for every send context, drain up to max_merge_var_num_
// queued gradient snapshots per variable, merge them, and dispatch to the
// right push path (global step / sparse / dense). Each context runs as a
// task on send_threadpool_; this function blocks until all tasks finish.
void AsyncCommunicator::SendByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto& iter : send_varname_to_ctx_) {
    auto& ctx = iter.second;
    // NOTE(review): `ctx` is captured by reference into a pooled task; safe
    // only because tasks are awaited before this function returns.
    auto send_recv_task = [this, &ctx] {
      auto& varnames = ctx.origin_varnames;
      auto& table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      // All vars of one context are enqueued together, so the first var's
      // queue size indicates whether a full snapshot is available.
      auto& check_queue = send_varname_to_queue_[varnames[0]];
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      int merged_var_num = 0;
      int wait_times = 0;
      // Collect up to max_merge_var_num_ snapshots, waiting (bounded by
      // send_wait_times_) when the queue runs dry.
      while (merged_var_num < max_merge_var_num_) {
        if (check_queue->Size() == 0) {
          VLOG(4) << "wait_times -> " << wait_times;
          if (wait_times >= send_wait_times_) {
            break;
          }
          std::this_thread::sleep_for(std::chrono::milliseconds(10));
          wait_times++;
          continue;
        } else {
          wait_times = 0;
          for (size_t i = 0; i < var_nums; i++) {
            auto& var_name = varnames[i];
            auto& var_queue = send_varname_to_queue_[var_name];
            vars[i].push_back(var_queue->Pop());
          }
          merged_var_num++;
        }
      }
      if (merged_var_num == 0) return;
      // Merge the collected snapshots into send_scope_.
      for (size_t i = 0; i < var_nums; i++) {
        auto& var_name = varnames[i];
        if (var_name == STEP_COUNTER) {
          MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
        } else {
          MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
        }
      }
      // Dispatch by context kind.
      if (ctx.is_tensor_table) {
        SendGlobalStep(ctx, merged_var_num, send_scope_.get());
      } else if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(), 1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
        // Without an independent recv thread, pull the table right after
        // pushing its gradients.
        if (!independent_recv_ &&
            recv_varname_to_ctx_.find(table_id) !=
                recv_varname_to_ctx_.end()) {
          auto recv_varnames = recv_varname_to_ctx_.at(table_id);
          RpcRecvDense(recv_varnames, table_id, recv_scope_);
        }
      }
      if (independent_recv_) {
        grad_num_.fetch_add(1, std::memory_order_relaxed);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto& task : tasks) {
    task.wait();
  }
  return;
}
// Notifies the independent receive thread that one more dense gradient
// batch has been pushed (no-op when that thread is disabled).
void AsyncCommunicator::PushDensePostProcessing() {
  if (!independent_recv_) {
    return;
  }
  grad_num_.fetch_add(1, std::memory_order_relaxed);
}
// Main send-loop thread: waits until the communicator is released from its
// waiting state, then repeatedly sends merged gradients and syncs the
// server-side profiler flag until stopped.
void AsyncCommunicator::MainThread() {
  VLOG(3) << "AsyncCommunicator MainThread start and wait";
  while (running_ && waiting_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  while (running_) {
    SendByCommunicator();
    RpcProfilerControl();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Synchronously pulls sparse embeddings for the ids in `inputs` from table
// `table_id` into the matching `outputs` tensors. Ids equal to `padding_id`
// get a zero vector instead of a lookup. Output tensors are filled in order,
// fea_dim floats per id, advancing to the next output when one is full.
void AsyncCommunicator::PullSparseToTensorSync(
    const uint64_t table_id, int fea_dim, uint64_t padding_id,
    platform::Place place, bool is_training,
    std::vector<const LoDTensor*>* inputs, std::vector<LoDTensor*>* outputs) {
  std::vector<uint64_t> fea_keys;
  std::vector<float*> pull_result_ptr;
  fea_keys.reserve(MAX_FEASIGN_NUM / 100);
  pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
  // Zero vector written for padding ids.
  std::vector<float> init_value(fea_dim, 0);
  framework::LoDTensor* output = nullptr;
  float* output_data = nullptr;
  // Starts at SIZE_MAX so the first ++ lands on output 0.
  size_t output_index = -1;
  size_t output_len = 0;
  for (size_t index = 0; index < inputs->size(); ++index) {
    const framework::LoDTensor* tensor = inputs->at(index);
    const int64_t* ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
      // Move to the next output tensor when the current one is exhausted.
      if (!output || output_len == size_t(output->numel())) {
        ++output_index;
        CHECK(output_index < outputs->size());  // NOLINT
        output = outputs->at(output_index);
        output->set_lod(tensor->lod());
        output_data = output->mutable_data<float>(place);
        output_len = 0;
        CHECK(output->numel() % fea_dim == 0);  // NOLINT
        CHECK(output_data != nullptr);          // NOLINT
      }
      uint64_t real_id = static_cast<uint64_t>(ids[i]);
      if (real_id == padding_id) {
        // Padding id: write zeros, skip the RPC lookup.
        memcpy(output_data + output_len, init_value.data(),
               sizeof(float) * fea_dim);
        continue;
      }
      fea_keys.push_back(real_id);
      pull_result_ptr.push_back(output_data + output_len);
    }
  }
  auto status =
      _worker_ptr->PullSparse(pull_result_ptr.data(), table_id,
                              fea_keys.data(), fea_keys.size(), is_training);
  status.wait();
  auto ret = status.get();
  if (ret != 0) {
    // Best-effort back-off on failure; the pull result is left as-is.
    LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
    sleep(sleep_seconds_before_fail_exit_);
  }
}
// Converts per-slot embedding gradients (`outputs`) keyed by the ids in
// `inputs` into CtrCommonPushValue-style rows and pushes them to sparse
// table `table_id`. Ids equal to `padding_id` are skipped. `shows`/`clks`
// are validated but their values are not yet consumed (see TODOs below).
void AsyncCommunicator::PushSparseFromTensorAsync(
    const uint64_t table_id, int fea_dim, uint64_t padding_id,
    platform::Place place, std::vector<const framework::LoDTensor*>* inputs,
    const framework::LoDTensor* shows, const framework::LoDTensor* clks,
    std::vector<framework::LoDTensor*>* outputs) {
  // Determine the batch size; when inputs disagree, gradient scaling below
  // is skipped (batch_size_consist == false).
  int batch_size = -1;
  bool batch_size_consist = true;
  for (auto* input : *inputs) {
    int cur_batch_size =
        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
    if (batch_size == -1) {
      batch_size = cur_batch_size;
    } else if (batch_size != cur_batch_size) {
      // CHECK(batch_size == cur_batch_size);  // NOLINT
      batch_size_consist = false;
      break;
    }
  }
  CHECK(batch_size > 0);  // NOLINT

  // show/clk tensors must be per-sample or a single broadcast value.
  int show_size =
      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
  CHECK(show_size == batch_size || show_size == 1);
  int clk_size =
      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
  CHECK(clk_size == batch_size || clk_size == 1);

  CHECK(outputs->size() == inputs->size());
  std::vector<uint64_t> push_keys;
  push_keys.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<std::vector<float>> push_values;
  push_values.reserve(MAX_FEASIGN_NUM / 100);
  size_t output_len = 0;
  size_t input_idx = 0;

  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim << " batch_size: " << batch_size
          << " batch_size_consist: " << batch_size_consist;

  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
  // const long int* show_tensor = shows->data<int64_t>();
  // const long int* clk_tensor = clks->data<int64_t>();

  for (size_t index = 0; index < inputs->size(); ++index) {
    framework::LoDTensor* g_tensor = outputs->at(index);
    float* g = g_tensor->data<float>();
    if (batch_size_consist) {
      // TODO(zhaocaibei123): add config
      // scale_sparse_gradient_with_batch_size_
      // Scale all but the first two columns by batch size, in place.
      Eigen::Map<
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          g_mat(g, g_tensor->numel() / fea_dim, fea_dim);
      g_mat.rightCols(fea_dim - 2) *=
          batch_size;  // hard code here, because of cvm_grad op
    }

    const framework::LoDTensor* tensor = inputs->at(index);
    const int64_t* ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    output_len = 0;

    if (tensor->lod().size() > 0) {
      // LoD path: iterate samples, then the ids within each sample.
      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
        for (size_t j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
             ++j, output_len += fea_dim) {
          uint64_t real_id = static_cast<uint64_t>(ids[j]);
          if (real_id == padding_id) {
            continue;
          }
          push_keys.emplace_back(real_id);
          push_values.emplace_back(fea_dim + 1);
          // slot show clk grad... consistent with CtrCommonPushValue defined
          // in ctr_accessor.h
          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
          // push_values.back()[1] =
          //     (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
          // push_values.back()[2] =
          //     (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
          float* data = push_values.back().data() + 1;  // hard code here
          memcpy(data, g + output_len, sizeof(float) * fea_dim);
          ++input_idx;
        }
      }
    } else {
      // Flat path: one id per row.
      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
        uint64_t real_id = static_cast<uint64_t>(ids[i]);
        if (real_id == padding_id) {
          continue;
        }
        push_keys.emplace_back(real_id);
        push_values.emplace_back(fea_dim + 1);
        // slot show clk grad... consistent with CtrCommonPushValue defined
        // in ctr_accessor.h
        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
        // push_values.back()[1] =
        //     (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
        // push_values.back()[2] =
        //     (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
        float* data = push_values.back().data() + 1;
        memcpy(data, g + output_len, sizeof(float) * fea_dim);
        ++input_idx;
      }
    }
    // Every gradient element must have been consumed.
    CHECK(static_cast<int64_t>(output_len) == g_tensor->numel());
  }

  std::vector<float*> push_g_vec(input_idx, nullptr);
  for (auto i = 0u; i < push_keys.size(); ++i) {
    push_g_vec[i] = push_values.at(i).data();
  }
  PADDLE_ENFORCE_EQ(
      this->Check(table_id), true,
      platform::errors::InvalidArgument(
          "can not find table: %s, please check your config", table_id));
  auto status = _worker_ptr->PushSparse(table_id, push_keys.data(),
                                        (const float**)push_g_vec.data(),
                                        push_keys.size());
}
// Worker loop of the half-async communicator: idle until Start() clears
// waiting_, then repeat send -> barrier -> recv -> barrier -> wake-up rounds
// until running_ is cleared.
void HalfAsyncCommunicator::MainThread() {
  VLOG(3) << "HalfAsyncCommunicator MainThread start and wait";
  // Idle phase: poll until the communicator is released (or shut down).
  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  // Work phase: one synchronized send/recv round per iteration.
  while (running_) {
    SendByCommunicator();
    BarrierSend();
    RecvByCommunicator();
    BarrierRecv();
    BarrierWeakUp();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Initializes send/recv contexts, one bounded queue per send variable, the
// scratch scopes, and the thread pool used to ship gradients.
//
// @param send_varname_to_ctx  grad var name -> send CommContext (copied in).
// @param recv_varname_to_ctx  table id -> dense var names (copied in).
// @param recv_scope           borrowed pointer to the trainer's parameter
//                             scope; ownership stays with the caller.
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RecvCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  // The previous code wrapped these assignments in std::move(), but moving
  // from a const lvalue reference degrades to a copy anyway (clang-tidy:
  // performance-move-const-arg); plain assignment states the real behavior.
  send_varname_to_ctx_ = send_varname_to_ctx;
  recv_varname_to_ctx_ = recv_varname_to_ctx;
  recv_scope_ = recv_scope;
  send_scope_.reset(new Scope());
  xpu_temp_scope_.reset(new Scope());
  // One blocking queue per original (un-split) send variable.
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    auto &varnames = ctx.origin_varnames;
    for (auto &var_name : varnames) {
      send_varname_to_queue_[var_name] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
// Destructor: signal both worker loops to exit, then wait for each thread
// that was actually started.
AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) {
    main_thread_->join();
  }
  if (recv_thread_) {
    recv_thread_->join();
  }
}
void
AsyncCommunicator
::
Start
()
{
VLOG
(
1
)
<<
"Communicator start"
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
VLOG
(
1
)
<<
"start send thread and recv thread"
;
waiting_
=
true
;
running_
=
true
;
// flushing_ = false;
BarrierTriggerReset
(
max_merge_var_num_
);
// start send and recv thread
main_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncCommunicator
::
MainThread
,
this
)));
if
(
independent_recv_
)
{
recv_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncCommunicator
::
RecvThread
,
this
)));
}
}
}
// Stops the async communicator: clears running_ so the loops exit, then joins
// and releases the recv and main threads (in that order).
void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop begin";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    // _worker_ptr->FinalizeWorker();
    VLOG(1) << "client finalize_worker done";
    if (recv_thread_) {
      VLOG(1) << "stop recv thread";
      recv_thread_->join();
      recv_thread_.reset(nullptr);
    }
    if (main_thread_) {
      VLOG(1) << "stop main thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}
// Checks that the single table name in `var_tables` has a registered send
// context. As a side effect, when the name is STEP_COUNTER, pushes a
// one-element int64 tensor with value 1 into that table's send queue.
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
  PADDLE_ENFORCE_EQ(
      var_tables.size(), 1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
  auto table_name = var_tables[0];
  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
    return false;
  }
  if (table_name == STEP_COUNTER) {
    VLOG(3) << "send step_counter into queue";
    auto step_var = std::make_shared<Variable>();
    auto *step_tensor = step_var->GetMutable<framework::LoDTensor>();
    step_tensor->Resize(phi::make_ddim({1}));
    auto *step_data = step_tensor->mutable_data<int64_t>(platform::CPUPlace());
    step_data[0] = 1;
    send_varname_to_queue_[table_name]->Push(step_var);
  }
  return true;
}
// Returns true if any registered send context targets `table_id`.
bool AsyncCommunicator::Check(const int table_id) {
  for (auto &kv : send_varname_to_ctx_) {
    if (kv.second.table_id == table_id) {
      return true;
    }
  }
  return false;
}
// Copies each named variable out of `scope` and pushes the snapshot into the
// variable's send queue; clears waiting_ so MainThread starts working.
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const framework::Scope &scope) {
  waiting_ = false;
  for (const auto &name : var_names) {
    auto *src_var = scope.FindVar(name);
    auto grad_copy = std::make_shared<Variable>();
    framework::CopyVariable(*src_var, grad_copy.get());
    send_varname_to_queue_[name]->Push(grad_copy);
  }
}
// Discards every pending snapshot from every send queue.
void HalfAsyncCommunicator::Clean() {
  for (auto &entry : send_varname_to_queue_) {
    auto &name = entry.first;
    auto &queue = entry.second;
    while (queue->Size() > 0) {
      queue->Pop();
    }
    VLOG(3) << "clean var: " << name << " done";
  }
}
// Atomically lowers the barrier trigger by one (a trainer left the round).
void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  --barrier_trigger_;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}
// Resets the barrier trigger to `initial_val` (the number of arrivals
// BatchesCounter() will wait for).
void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_.store(initial_val);
  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}
void
HalfAsyncCommunicator
::
Barrier
()
{
barrier_counter_
++
;
if
(
!
running_
)
{
VLOG
(
3
)
<<
"Communicator is not running, release barrier"
;
return
;
}
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
barrier_mutex_
);
barrier_cond_
.
wait
(
lk
,
[
this
]
{
return
(
barrier_counter_
==
0
);
});
}
}
// Polls until at least `barrier_trigger_` arrivals have been counted (and
// the trigger is non-zero), or until shutdown; returns the arrival count.
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    const int trigger = barrier_trigger_.load();
    if (trigger != 0 && barrier_counter_.load() >= trigger) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  return barrier_counter_.load();
}
// Drains `batches` pending gradient snapshots from every send queue, merges
// them (sum) into send_scope_, and pushes the merged result to the parameter
// server: one thread-pool task per send context, all awaited before return.
void HalfAsyncCommunicator::SendByCommunicator() {
  // Blocks until the barrier trigger condition is met (or shutdown).
  int batches = BatchesCounter();
  VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches;
  if (batches <= 0) return;
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    // `ctx` is captured by reference: it lives in the member map, which
    // outlives the tasks awaited below. `batches` is captured by value.
    auto send_recv_task = [this, &ctx, batches] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        auto &var_queue = send_varname_to_queue_[var_name];
        // Pop exactly `batches` snapshots, then sum-merge them into
        // send_scope_ (trailing 1 == merge_add).
        for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop());
        MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
      }
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(), 1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  // Wait for every per-context send before the caller's barrier step.
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Resets the arrival counter and releases every thread blocked in Barrier().
void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}
// Send-side barrier: rendezvous on barrier type 0; no-op once stopped.
void SyncCommunicator::BarrierSend() {
  if (!running_) {
    return;
  }
  BarrierWithTable(0);
  VLOG(4) << "BarrierSend with SyncCommunicator";
}
// Recv-side barrier: rendezvous on barrier type 1; no-op once stopped.
void SyncCommunicator::BarrierRecv() {
  if (!running_) {
    return;
  }
  BarrierWithTable(1);
  VLOG(4) << "BarrierRecv with SyncCommunicator";
}
// Records which sparse ids of `var_names[0]` were touched this step and
// enqueues them (deduplicated, bucketed per pserver split) for the async
// send path; the actual delta push happens later in MainThread.
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const framework::Scope &scope) {
  platform::RecordEvent record_event("GeoCommunicator->Send",
                                     platform::TracerEventType::Communication,
                                     1);
  waiting_ = false;
  auto before_send = GetCurrentUS();
  auto table_name = var_names[0];
  size_t splited_var_nums =
      send_varname_to_ctx_[table_name].splited_varnames.size();
  // One (initially empty) id set per split variable.
  std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
  for (size_t j = 0; j < splited_var_nums; j++) {
    ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
        send_varname_to_ctx_[table_name].splited_varnames[j],
        std::unordered_set<int64_t>()));
  }
  auto *var = scope.FindVar(table_name);
  PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true,
                    platform::errors::InvalidArgument(
                        "Only need to send Sparse Grad in Geo mode."));
  auto &rows = var->Get<phi::SelectedRows>().rows();
  // insert ids which has not been record
  // Bucket each row id by `id % splited_var_nums` — assumes ids are
  // non-negative (TODO confirm upstream guarantees this).
  for (size_t j = 0; j < rows.size(); j++) {
    auto ep_idx = rows[j] % splited_var_nums;
    ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
        .insert(rows[j]);
  }
  for (auto &iter : ids_table) {
    auto &key = iter.first;
    auto &sparse_ids_set = iter.second;
    auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
    sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
    sparse_id_queues_.at(key)->Put(sparse_ids_vec);
    VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
            << "'s queue";
  }
  auto after_send = GetCurrentUS();
  VLOG(2) << "run send op finish. use time " << (after_send - before_send);
}
// Initializes send/recv contexts for GEO mode: counts one parallel task per
// dense context and one per sparse split (creating the per-split id
// channel), then creates the thread pool and the delta/old/pserver scopes.
//
// @param send_varname_to_ctx  grad var name -> send CommContext (copied in).
// @param recv_varname_to_ctx  table id -> dense var names (copied in).
// @param recv_scope           borrowed pointer to the trainer's parameter
//                             scope; ownership stays with the caller.
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RecvCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  // std::move() on a const& parameter is just a copy (clang-tidy:
  // performance-move-const-arg); assign plainly so the intent is clear and
  // the parameter used in the size check below is unambiguously intact.
  send_varname_to_ctx_ = send_varname_to_ctx;
  recv_varname_to_ctx_ = recv_varname_to_ctx;
  recv_scope_ = recv_scope;
  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(), 0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    if (!ctx.is_sparse) {
      // Dense contexts need exactly one send/recv task.
      parallel_task_nums_ += 1;
      continue;
    }
    auto &varnames = ctx.origin_varnames;
    PADDLE_ENFORCE_EQ(
        varnames.size(), 1,
        platform::errors::InvalidArgument(
            "sparse variables can only be merged by one variables"));
    // Sparse contexts get one task and one bounded id channel per split
    // (i.e. per parameter server).
    for (auto &splited_var : ctx.splited_varnames) {
      parallel_task_nums_ += 1;
      sparse_id_queues_.insert(
          std::pair<std::string, paddle::framework::Channel<
                                     std::shared_ptr<std::vector<int64_t>>>>(
              splited_var,
              paddle::framework::MakeChannel<
                  std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());
}
// Synchronizes initial parameter values across trainers: dense tables in
// parallel via the thread pool, then sparse tables sequentially.
void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());
  for (auto &iter : recv_varname_to_ctx_) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    // Captured by reference: both refer into the member map, which outlives
    // the tasks awaited below.
    auto recv_task = [this, &table_id, &varnames] {
      InitDense(varnames, table_id);
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    if (!ctx.is_sparse) continue;
    auto &varname = ctx.origin_varnames[0];
    auto &table_id = ctx.table_id;
    // Strip the trailing 5 characters — presumably the "@GRAD" suffix (the
    // naming convention noted elsewhere in this file) — to recover the
    // parameter name.
    auto param = varname.substr(0, varname.size() - 5);
    InitSparse(param, table_id);
  }
  return;
}
// Makes every trainer start from identical dense parameters: trainer 0
// pushes its values to the table, the others pull them after the barrier;
// then each trainer snapshots the result into old_scope_ and pserver_scope_.
void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
                                int table_id) {
  if (trainer_id_ == 0) {
    RpcSendDenseParam(varnames, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push dense param to table " << table_id
            << " from 0' trainer done";
  } else {
    // Wait until trainer 0 has pushed before pulling.
    BarrierWithTable(1);
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param to table " << table_id
            << " from 0' trainer done";
  }
  // copy to old_scope
  for (auto &t : varnames) {
    auto *global_var = recv_scope_->FindVar(t);
    global_var->GetMutable<framework::LoDTensor>();
    auto *old_var = old_scope_->Var(t);
    old_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, old_var);
    // init pserver_scope_
    auto *pserver_var = pserver_scope_->Var(t);
    pserver_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, pserver_var);
  }
  VLOG(1) << "init dense table " << table_id << " done";
}
// Computes the averaged local delta for each dense var in `send_ctx`
// (delta = (latest - old) / trainers_), folds the delta back into the
// old-snapshot tensor, and pushes all deltas from delta_scope_.
void GeoCommunicator::SendDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->SendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = send_ctx.origin_varnames;
  auto &table_id = send_ctx.table_id;
  for (auto &varname : var_names) {
    auto param_name = GradToParam(varname);
    auto *var_latest = recv_scope_->FindVar(param_name);
    auto *var_timestamp = old_scope_->FindVar(param_name);
    PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    auto &t_latest = var_latest->Get<framework::LoDTensor>();
    auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
    phi::CPUContext cpu_ctx;
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    // delta = latest - old
    blas.VSUB(t_latest.numel(), t_latest.data<float>(),
              t_timestamp->data<float>(), t_delta->data<float>());
    // delta /= trainers_ (average local progress across trainers)
    float coefficient = 1.0 / static_cast<float>(trainers_);
    blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
    // old += delta, so the snapshot tracks what was reported upstream
    blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
              t_delta->data<float>(), t_timestamp->data<float>());
  }
  RpcSendDense(send_ctx, *delta_scope_);
  VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
  return;
}
// Pulls the pserver's dense values, applies the server-side delta
// (pserver - old) to the local latest parameters, and refreshes the
// old-snapshot scope so it mirrors the server again.
void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = send_ctx.table_id;
  auto &varnames = recv_varname_to_ctx_.at(table_id);
  // 1. recv from pserver
  RpcRecvDense(varnames, table_id, pserver_scope_.get());
  // 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver
  phi::CPUContext cpu_ctx;
  for (auto &varname : varnames) {
    auto *var_latest = recv_scope_->FindVar(varname);
    auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
    auto *var_old = old_scope_->FindVar(varname);
    auto t_old = var_old->GetMutable<framework::LoDTensor>();
    auto *var_pserver = pserver_scope_->FindVar(varname);
    // Bind by const reference: Get<>() returns a const&, and the previous
    // plain `auto` silently copy-constructed the tensor object
    // (clang-tidy: performance-unnecessary-copy-initialization).
    const auto &t_pserver = var_pserver->Get<framework::LoDTensor>();
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    // delta = pserver - old
    blas.VSUB(t_latest->numel(), t_pserver.data<float>(), t_old->data<float>(),
              t_delta->data<float>());
    // latest += delta
    blas.VADD(t_latest->numel(), t_latest->data<float>(),
              t_delta->data<float>(), t_latest->data<float>());
    // old = pserver
    blas.VCOPY(t_latest->numel(), t_pserver.data<float>(),
               t_old->data<float>());
  }
  VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
  return;
}
// Makes every trainer start from identical sparse parameters: trainer 0
// pushes the table, the others pull it after the barrier; finally the result
// is snapshotted into old_scope_.
void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " begin.";
  if (trainer_id_ == 0) {
    RpcSendSparseParam(var_name, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push sparse param to table " << table_id
            << " from 0' trainer done";
  } else {
    // Wait until trainer 0 has pushed before pulling.
    BarrierWithTable(1);
    RpcRecvSparse(var_name, table_id, recv_scope_);
    VLOG(1) << "pull sparse param to table " << table_id
            << " from 0' trainer done";
  }
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " done.";
  // Snapshot the freshly synchronized value for later delta computation.
  auto *global_var = recv_scope_->FindVar(var_name);
  auto *snapshot_var = old_scope_->Var(var_name);
  framework::CopyVariable(*global_var, snapshot_var);
  return;
}
// Drains up to max_merge_var_num_ id batches for `send_varname` from its
// channel, deduplicating into one set; gives up after send_wait_times_
// consecutive empty polls. Returns the merged ids as a vector.
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
    const std::string &send_varname) {
  platform::RecordEvent record_event("GeoCommunicator->MergeSparseIds",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t merge_num = 0, wait_times = 0;
  std::unordered_set<int64_t> sparse_ids;
  while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
    VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
    if (sparse_id_queues_.at(send_varname)->Size() > 0) {
      wait_times = 0;  // progress made: reset the empty-poll counter
      std::shared_ptr<std::vector<int64_t>> pop_ids = nullptr;
      sparse_id_queues_.at(send_varname)->Get(pop_ids);
      for (size_t j = 0; j < pop_ids->size(); j++) {
        sparse_ids.insert(pop_ids->at(j));
      }
      merge_num += 1;
      VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
    } else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
      // Channel currently empty: back off briefly, bail out after
      // send_wait_times_ consecutive empty polls.
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    }
  }
  std::vector<int64_t> res;
  res.assign(sparse_ids.begin(), sparse_ids.end());
  return res;
}
// For each touched sparse id, computes the averaged local delta
// (delta = (latest_row - old_row) / trainers_), folds it back into the old
// snapshot, and pushes the delta rows asynchronously to pserver `ep_idx`.
// No-op when `sparse_ids` is empty.
void GeoCommunicator::SendSparse(const std::string &varname,
                                 std::vector<int64_t> &sparse_ids,
                                 int table_id, int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->SendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  if (sparse_ids.size() == 0) {
    return;
  }
  std::string param_name = SplitedGradToParam(varname);
  VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
          << ", ep_idx: " << ep_idx;
  auto *var_latest = recv_scope_->FindVar(param_name);
  auto *var_old = old_scope_->FindVar(param_name);
  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();
  // dims1 is the embedding width (second tensor dimension).
  auto dims1 = t_latest.dims()[1];
  phi::CPUContext cpu_ctx;
  // Stage the per-id deltas as a SelectedRows in delta_scope_.
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<phi::SelectedRows>();
  auto *var_t_value = t_delta->mutable_value();
  var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1});
  auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace());
  t_delta->set_rows(sparse_ids);
  t_delta->set_height(t_latest.dims()[0]);
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);
  std::vector<float *> push_g_vec;
  for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
    // delta_row = latest_row - old_row
    blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1);
    // delta_row /= trainers_
    blas.SCAL(dims1, coefficient, t_value + j * dims1);
    // old_row += delta_row (snapshot tracks what was reported upstream)
    blas.VADD(dims1, t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1);
    push_g_vec.push_back(t_value + j * dims1);
    VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
            << sparse_ids[j] << " value[0] " << push_g_vec[j][0]
            << " value[-1] " << push_g_vec[j][dims1 - 1];
  }
  // Track the in-flight RPC; decremented in the closure when the response
  // arrives.
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      1, [this](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) {
          ret = -1;
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushSparseRawGradientPartial(
      table_id, (const uint64_t *)sparse_ids.data(),
      (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
  status.wait();
  VLOG(1) << "Finish Send Sparse " << varname
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id;
  return;
}
// Pulls GEO sparse updates (changed keys + values) from pserver `ep_idx` and
// applies, per key: delta = pserver - old; latest += delta; old = pserver.
void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  // 1. recv from pserver
  std::vector<uint64_t> keys;
  std::vector<float> values;
  auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
  status.wait();
  std::string param = SplitedGradToParam(varname);
  VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", "
          << table_id << "; ids Size: " << keys.size()
          << "; values size: " << values.size();
  auto *var_latest = recv_scope_->FindVar(param);
  auto *var_old = old_scope_->FindVar(param);
  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();
  // dims1 is the embedding width; values is laid out row-major,
  // one dims1-wide row per key.
  auto dims1 = t_latest->dims()[1];
  auto numel = keys.size() * dims1;
  std::vector<float> v_delta;
  v_delta.resize(numel);
  phi::CPUContext cpu_ctx;
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
    VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
            << "value[0] " << values[j * dims1] << " value[-1] "
            << values[j * dims1 + dims1 - 1];
    float *latest_data = t_latest->data<float>() + keys[j] * dims1;
    float *old_data = t_old->data<float>() + keys[j] * dims1;
    // pserver - old => delta
    blas.VSUB(dims1, values.data() + j * dims1, old_data,
              v_delta.data() + j * dims1);
    // latest + delta => latest
    blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
    // pserver => old
    blas.VCOPY(dims1, values.data() + j * dims1, old_data);
  }
  VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id;
}
// GEO worker loop: after the idle phase, each round enqueues one task per
// (sparse split, pserver) pair — merge ids, send deltas, recv updates — and
// one task per dense context (send then recv), waiting for all tasks before
// starting the next round.
void GeoCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";
  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  while (running_) {
    std::vector<std::future<void>> tasks;
    tasks.reserve(parallel_task_nums_);
    for (auto &iter : send_varname_to_ctx_) {
      auto &ctx = iter.second;
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(), 1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        int pserver_num = static_cast<int>(ctx.epmap.size());
        for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
          // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0
          // table_id/ep_idx are captured by value (loop-locals); ctx by
          // reference — it lives in the member map, which outlives the
          // tasks awaited below.
          auto send_recv_task = [this, table_id, ep_idx, &ctx] {
            auto splited_varname = ctx.splited_varnames[ep_idx];
            auto sparse_ids = MergeSparseIds(splited_varname);
            SendSparse(splited_varname, sparse_ids, table_id, ep_idx);
            RecvSparse(splited_varname, table_id, ep_idx);
          };
          tasks.emplace_back(
              send_threadpool_->enqueue(std::move(send_recv_task)));
        }
      } else {
        auto send_recv_task = [this, &ctx] {
          SendDense(ctx);
          RecvDense(ctx);
        };
        tasks.emplace_back(
            send_threadpool_->enqueue(std::move(send_recv_task)));
      }
    }
    for (auto &task : tasks) {
      task.wait();
    }
  }
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/communicator/communicator.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
// Forward declarations so pointer/reference users of these types do not need
// the heavyweight service headers.
namespace paddle {
namespace distributed {

class PSClient;
struct CommContext;

}  // namespace distributed
}  // namespace paddle

// gflags flag declared (and defined) elsewhere; consulted by the
// communicator implementation.
DECLARE_bool(communicator_is_sgd_optimizer);

namespace paddle {
namespace distributed {

// Shorthand aliases into the framework namespace.
using Scope = framework::Scope;
using Variable = framework::Variable;
// Bounded, thread-safe FIFO queue. Push blocks while the queue is full, Pop
// blocks while it is empty. Two condition variables (plus waiter counters)
// keep producers and consumers from waking each other spuriously.
//
// NOTE(review): WaitForWrite/WaitForRead/EmptyUnlocked/FullUnlocked/Notify
// assume mutex_ is already held by the caller; they are used internally by
// Push/Pop under the lock.
template <typename T>
class BlockingQueue {
 public:
  // Capacity must be positive; it bounds queue_.size().
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(capacity_, 0,
                      platform::errors::InvalidArgument(
                          "The capacity must be greater than 0."));
  }

  // Copy-push: blocks until there is room, then appends and notifies.
  bool Push(const T &elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.push_back(elem);
    Notify();
    return true;
  }

  // Blocks (releasing `lock` while waiting) until the queue has room.
  bool WaitForWrite(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (FullUnlocked()) {
      // Queue is full: make sure a blocked reader gets a chance to drain it.
      if (empty_waiters_ != 0) {
        empty_cond_.notify_one();
      }
      full_waiters_++;
      full_cond_.wait(lock);
      full_waiters_--;
    }
    return true;
  }

  // Blocks (releasing `lock` while waiting) until the queue has an element.
  bool WaitForRead(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (EmptyUnlocked()) {
      // Queue is empty: make sure a blocked writer gets a chance to fill it.
      if (full_waiters_ != 0) {
        full_cond_.notify_one();
      }
      empty_waiters_++;
      empty_cond_.wait(lock);
      empty_waiters_--;
    }
    return true;
  }

  // True when the queue holds no elements (caller must hold mutex_).
  bool EmptyUnlocked() { return queue_.empty(); }

  // True when the queue is at (or beyond) capacity (caller must hold mutex_).
  bool FullUnlocked() { return queue_.size() >= capacity_; }

  // Wakes one reader if data is available and one writer if room is
  // available (caller must hold mutex_).
  void Notify() {
    if (empty_waiters_ != 0 && (!EmptyUnlocked())) {
      empty_cond_.notify_one();
    }
    if (full_waiters_ != 0 && (!FullUnlocked())) {
      full_cond_.notify_one();
    }
  }

  // Move-push: blocks until there is room, then appends and notifies.
  bool Push(T &&elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.emplace_back(std::move(elem));
    Notify();
    return true;
  }

  // Blocks until an element is available, then removes and returns it.
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForRead(lock);
    T rc(std::move(queue_.front()));
    queue_.pop_front();
    Notify();
    return rc;
  }

  // Maximum number of elements the queue may hold.
  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  // Current number of elements.
  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  int empty_waiters_ = 0;  // threads blocked in WaitForRead
  int full_waiters_ = 0;   // threads blocked in WaitForWrite
  std::condition_variable empty_cond_;
  std::condition_variable full_cond_;
  const size_t capacity_;
  std::deque<T> queue_;
  mutable std::mutex mutex_;  // guards all state above
};
// Flat (rank-1) Eigen view over a framework tensor; row-major by default.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Merges the snapshots in `vars` (all holding the same variable `var_name`)
// into a new variable created in `scope`. LoDTensor inputs are summed
// element-wise (and divided by vars.size() when merge_add is false);
// SelectedRows inputs are combined with MergeAdd / MergeAverage.
// Throws for any other variable type.
template <typename T>
inline void MergeVars(const std::string &var_name,
                      const std::vector<std::shared_ptr<Variable>> &vars,
                      Scope *scope, bool merge_add = true) {
  PADDLE_ENFORCE_NE(vars.empty(), true,
                    platform::errors::InvalidArgument(
                        "vector vars are empty."));
  auto cpu_place = platform::CPUPlace();
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);
  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
            << "; merge add: " << merge_add;
    // init output tensor
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<T>(dims, cpu_place);
    // check the input dims
    for (auto &var : vars) {
      auto &var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(var_t.dims(), dims,
                        platform::errors::InvalidArgument(
                            "vars should have the same dims."));
    }
    // set output tensor to 0.
    phi::CPUContext cpu_ctx;
    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
    // sum all vars to out
    auto result = EigenVector<T>::Flatten(*out_t);
    for (auto &var : vars) {
      auto &in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<T>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
    if (!merge_add) {
      // Average instead of sum.
      result.device(*cpu_ctx.eigen_device()) =
          result / static_cast<T>(vars.size());
    }
  } else if (var0->IsType<phi::SelectedRows>()) {
    auto &slr0 = var0->Get<phi::SelectedRows>();
    auto *out_slr = out_var->GetMutable<phi::SelectedRows>();
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
    std::vector<const phi::SelectedRows *> inputs;
    inputs.reserve(vars.size());
    for (auto &var : vars) {
      inputs.push_back(&var->Get<phi::SelectedRows>());
    }
    phi::CPUContext dev_ctx;
    if (merge_add) {
      paddle::operators::math::scatter::MergeAdd<phi::CPUContext, T> merge_add;
      merge_add(dev_ctx, inputs, out_slr);
    } else {
      paddle::operators::math::scatter::MergeAverage<phi::CPUContext, T>
          merge_average;
      merge_average(dev_ctx, inputs, out_slr);
    }
    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
                                                   var0->Type()));
  }
}
// Send-side map: variable name -> communication context describing how to
// send it.
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
// Recv-side map: table id -> the dense variable names pulled from that table.
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
// Sparse id -> its row of float values.
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
// Base interface for trainer <-> parameter-server communication.
// Owns the brpc PSClient (_worker_ptr), the process-wide singleton
// (communicator_/init_flag_), and the env-string configuration map.
// Concrete policies (Async / HalfAsync / Sync / Geo, declared below)
// implement the pure-virtual Start/Stop/Check/Send hooks.
class Communicator {
 public:
  Communicator();

  // Builds the communicator from a string->string env map and caches the
  // common ids: barrier_table_id, trainer_id, trainers.
  // Throws std::out_of_range (via map::at) / std::invalid_argument (via
  // std::stoi) if a required key is missing or non-numeric.
  explicit Communicator(const std::map<std::string, std::string> &envs_) {
    VLOG(3) << "Communicator Init Envs";
    for (auto &iter : envs_) {
      envs[iter.first] = iter.second;
      VLOG(3) << iter.first << ": " << iter.second;
    }
    barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
    trainer_id_ = std::stoi(envs.at("trainer_id"));
    trainers_ = std::stoi(envs.at("trainers"));
  }

  // Creates _worker_ptr from the serialized PS descriptor and the host list.
  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);

  virtual std::vector<uint64_t> GetClientInfo();

  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

  // 1. recv dense param
  virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                            int table_id,
                            Scope *scope);
  // 2. send dense param
  virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
                                 int table_id,
                                 const Scope &scope);
  // 3. send dense grad
  virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
  // 4. send sparse grad
  virtual void RpcSendSparse(const std::string &var_name,
                             int table_id,
                             const Scope &scope);
  // 5. send sparse param
  virtual void RpcSendSparseParam(const std::string &varname,
                                  int table_id,
                                  const Scope &scope);
  // 6. recv sparse param
  virtual void RpcRecvSparse(const std::string &varname,
                             int table_id,
                             Scope *scope);
  // 7. send global step
  virtual void SendGlobalStep(const CommContext &ctx,
                              int batches,
                              Scope *send_scope);

  virtual ~Communicator() {}

  virtual void RpcProfilerControl();

  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

  // note: only for pull dense param first before training
  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

  virtual void Start() = 0;
  virtual void Stop() = 0;

  virtual bool IsRunning() { return running_; }

  virtual void Clean() {}

  virtual bool Check(const int table_id) = 0;
  virtual bool Check(const std::vector<std::string> &var_tables) = 0;

  virtual void Send(const std::vector<std::string> &var_names,
                    const framework::Scope &scope) = 0;

  virtual void RecvNoBarrier() {}

  virtual void Barrier() {}

  // Blocks on a pserver-side barrier table until all trainers arrive;
  // enforces that the RPC returned status 0.
  virtual void BarrierWithTable(uint32_t barrier_type) {
    auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
    rets.wait();
    int status = rets.get();
    PADDLE_ENFORCE_EQ(
        status,
        0,
        platform::errors::InvalidArgument(
            "The ret status must be 0 when barrier with table"));
  }

  // Opens direct client-to-client (trainer-to-trainer) brpc connections.
  virtual void CreateC2CConnection(int pserver_timeout_ms,
                                   int pserver_connect_timeout_ms,
                                   int max_retry) {
    _worker_ptr->CreateClient2ClientConnection(
        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  }

  virtual void BarrierTriggerDecrement() {}

  virtual void BarrierTriggerReset(int init_counter) {}

  virtual void InitEnvs() = 0;

  virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                        const RecvCtxMap &recv_varname_to_ctx,
                        Scope *recv_scope) {}

  // Singleton accessors; may return null before InitInstance has run.
  static Communicator *GetInstance() { return communicator_.get(); }

  static std::shared_ptr<Communicator> GetInstantcePtr() {
    return communicator_;
  }

  // Thread-safe one-time construction of the singleton via std::call_once;
  // T is the concrete communicator type to instantiate.
  template <typename T>
  static Communicator *InitInstance(
      const RpcCtxMap &send_ctx,
      const RecvCtxMap &recv_ctx,
      const std::string &dist_desc,
      const std::vector<std::string> &host_sign_list,
      Scope *recv_scope,
      const std::map<std::string, std::string> &envs) {
    std::call_once(init_flag_,
                   &Communicator::InitWithRpcCtx<T>,
                   send_ctx,
                   recv_ctx,
                   dist_desc,
                   host_sign_list,
                   recv_scope,
                   std::ref(envs));
    return communicator_.get();
  }

  // Init is called by InitInstance.
  // Order matters: envs first, then the brpc client, then the impl hooks.
  template <typename T>
  static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
                             const RecvCtxMap &recv_ctx,
                             const std::string &dist_desc,
                             const std::vector<std::string> &host_sign_list,
                             Scope *recv_scope,
                             const std::map<std::string, std::string> &envs) {
    if (communicator_.get() == nullptr) {
      communicator_.reset(new T(std::ref(envs)));
      communicator_->InitEnvs();
      communicator_->InitBrpcClient(dist_desc, host_sign_list);
      communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
    }
  }

  // Borrowing accessor; ownership stays with this object.
  PSClient *GetPsClient() { return _worker_ptr.get(); }

  // NOTE(review): this MOVES _worker_ptr out, leaving the member empty —
  // looks like an intentional ownership transfer to the caller, but any
  // later use of _worker_ptr here would dereference null. Confirm callers.
  std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
    return std::move(_worker_ptr);
  }

  RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker

 protected:
  bool running_ = false;
  bool waiting_ = true;
  bool flushing_ = false;
  bool do_server_profiler_ = false;
  static std::shared_ptr<Communicator> communicator_;
  static std::once_flag init_flag_;
  std::unordered_map<std::string, std::string> envs;

  // Computes how many dense elements each shard stores.
  // NOTE(review): always rounds up by adding 1, even when dense_dim_total
  // divides shard_num exactly — presumably intentional padding; confirm.
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  void InitGFlag(const std::string &gflags);

  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
  int servers_ = 0;
  int trainers_;
  int trainer_id_ = 0;
  int barrier_table_id_ = 0;
  RpcCtxMap send_varname_to_ctx_;
  RecvCtxMap recv_varname_to_ctx_;
  Scope *recv_scope_;  // should be global scope
  std::unique_ptr<Scope> xpu_temp_scope_;
  std::atomic<uint32_t> _async_call_num{0};
};
// Fully asynchronous communicator: a main send thread and (optionally) an
// independent recv thread exchange grads/params with the pservers without
// per-step synchronization. Configuration is read from the env map.
class AsyncCommunicator : public Communicator {
 public:
  AsyncCommunicator() : Communicator() {}

  explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
      : Communicator(envs) {}

  ~AsyncCommunicator();

  // Reads all async tuning knobs from envs; throws std::out_of_range /
  // std::invalid_argument on missing or non-numeric keys.
  void InitEnvs() {
    independent_recv_ = static_cast<bool>(
        std::stoi(envs.at("communicator_independent_recv_thread")));
    min_send_grad_num_before_recv_ =
        std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
  }

  void Start() override;

  void Stop() override;

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

  // Loop bodies of the send / recv worker threads.
  virtual void MainThread();
  virtual void RecvThread();

  virtual bool Check(const int table_id);
  virtual bool Check(const std::vector<std::string> &var_tables);

  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  virtual void SendByCommunicator();

  virtual void RecvByCommunicator();

  virtual void RecvNoBarrier();

  virtual int BatchesCounter() { return 1; }

  virtual void BarrierSend() {}

  virtual void BarrierRecv() {}

  virtual void BarrierWeakUp() {}

  void PushDensePostProcessing();

  // Pulls sparse embeddings for the ids in `inputs` into `outputs`,
  // blocking until the RPC completes. `padding_id` rows are skipped.
  void PullSparseToTensorSync(
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      bool is_training,
      std::vector<const framework::LoDTensor *> *inputs,  // NOLINT
      std::vector<framework::LoDTensor *> *outputs);      // NOLINT

  // Pushes sparse grads (with show/click signals) for table `table_id`
  // without waiting for completion.
  void PushSparseFromTensorAsync(
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      std::vector<const framework::LoDTensor *> *inputs,
      const framework::LoDTensor *shows,
      const framework::LoDTensor *clicks,
      std::vector<framework::LoDTensor *> *outputs);

 protected:
  // Per-variable bounded queues feeding the send thread.
  std::unordered_map<std::string,
                     std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
      send_varname_to_queue_;
  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
  int min_send_grad_num_before_recv_;
  int thread_pool_size_;
  int max_merge_var_num_;
  int send_wait_times_;
  int send_queue_size_;
  bool need_global_step_ = false;
  bool independent_recv_ = true;
  int parallel_task_nums_ = 0;
  int32_t sleep_seconds_before_fail_exit_;

  std::unique_ptr<std::thread> main_thread_{nullptr};
  std::unique_ptr<std::thread> recv_thread_{nullptr};

  std::unique_ptr<Scope> send_scope_;  // an independent scope

  std::atomic_uint grad_num_{0};  // the num of gradient sent since last recv
};
// Half-asynchronous communicator: sends are merged/batched like the async
// mode, but recv is forced to happen after send (no independent recv
// thread) and trainers rendezvous through a counter-based barrier.
class HalfAsyncCommunicator : public AsyncCommunicator {
 public:
  HalfAsyncCommunicator() {}

  explicit HalfAsyncCommunicator(
      const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  // Reads tuning knobs from envs; hard-codes the half-async invariants
  // (no independent recv, no grad threshold before recv).
  void InitEnvs() {
    // enforce to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));

    VLOG(1) << "HalfAsyncCommunicator Initialized";
  }

  void MainThread() override;

  void SendByCommunicator() override;

  void Clean() override;

  void Barrier() override;

  void BarrierTriggerDecrement() override;

  void BarrierTriggerReset(int initial_val) override;

  int BatchesCounter();

  void BarrierWeakUp();

 protected:
  // mutex for Wait for barrier
  std::mutex barrier_mutex_;
  std::condition_variable barrier_cond_;
  std::atomic<int64_t> barrier_trigger_{0};
  std::atomic<int64_t> barrier_counter_{0};
};
// Fully synchronous communicator: same recv-after-send flow as
// HalfAsyncCommunicator plus explicit send/recv barriers against the
// pserver endpoints each step.
class SyncCommunicator : public HalfAsyncCommunicator {
 public:
  SyncCommunicator() : HalfAsyncCommunicator() {}

  explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
      : HalfAsyncCommunicator(envs) {}

  // Reads tuning knobs from envs; hard-codes the sync invariants
  // (no independent recv, no grad threshold before recv).
  void InitEnvs() {
    // enforce to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));

    VLOG(1) << "SyncCommunicator Initialized";
  }

  void BarrierSend();

  void BarrierRecv();

 private:
  std::vector<std::string> pserver_endpoints_{};
};
class
GeoCommunicator
:
public
AsyncCommunicator
{
public:
GeoCommunicator
()
:
AsyncCommunicator
()
{}
explicit
GeoCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs
)
:
AsyncCommunicator
(
envs
)
{}
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RecvCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
override
;
void
InitParams
(
const
RecvCtxMap
&
recv_varname_to_ctx
)
override
;
void
InitDense
(
std
::
vector
<
std
::
string
>
&
varnames
,
int
table_id
);
// NOLINT
void
InitSparse
(
const
std
::
string
&
var_name
,
int
table_id
);
void
SendDense
(
const
CommContext
&
send_ctx
);
void
RecvDense
(
const
CommContext
&
send_ctx
);
std
::
vector
<
int64_t
>
MergeSparseIds
(
const
std
::
string
&
varname
);
void
SendSparse
(
const
std
::
string
&
varname
,
std
::
vector
<
int64_t
>
&
sparse_ids
,
// NOLINT
int
table_id
,
int
ep_idx
);
void
RecvSparse
(
const
std
::
string
&
varname
,
int
table_id
,
int
ep_idx
);
void
MainThread
()
override
;
void
InitEnvs
()
{
independent_recv_
=
false
;
min_send_grad_num_before_recv_
=
0
;
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
// id_queue's size
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_queue_size_
=
max_merge_var_num_
;
VLOG
(
1
)
<<
"GeoCommunicator Initialized"
;
}
void
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
framework
::
Scope
&
scope
)
override
;
void
SendByCommunicator
()
{
return
;
}
void
RecvByCommunicator
()
override
{
return
;
}
inline
std
::
string
GradToParam
(
const
std
::
string
var_name
)
{
std
::
string
param_name
=
var_name
.
substr
(
0
,
var_name
.
size
()
-
5
);
return
param_name
;
}
inline
std
::
string
SplitedGradToParam
(
const
std
::
string
delta_name
)
{
// delta_name: emb.delta0
auto
pos
=
delta_name
.
find
(
".block"
);
std
::
string
param_name
=
delta_name
.
substr
(
0
,
pos
);
return
param_name
;
}
private:
// parameter for delta calc and send
std
::
shared_ptr
<
Scope
>
delta_scope_
;
// parameter for storage the pserver param after last recv
std
::
shared_ptr
<
Scope
>
old_scope_
;
// parameter on pserver
std
::
shared_ptr
<
Scope
>
pserver_scope_
;
std
::
unordered_map
<
std
::
string
,
paddle
::
framework
::
Channel
<
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>
sparse_id_queues_
;
};
}
// namespace distributed
}
// namespace paddle
Prev
1
…
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment