Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
de2e6515
Commit
de2e6515
authored
Apr 26, 2023
by
yuguo960516yuguo
Browse files
2.4.1-dtk-23.04
parent
ad08b8ce
Pipeline
#228
failed with stages
in 0 seconds
Changes
272
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
7643 additions
and
0 deletions
+7643
-0
paddle/fluid/distributed/ps/service/communicator/communicator.cc
...fluid/distributed/ps/service/communicator/communicator.cc
+1590
-0
paddle/fluid/distributed/ps/service/communicator/communicator.h
.../fluid/distributed/ps/service/communicator/communicator.h
+715
-0
paddle/fluid/distributed/ps/service/communicator/communicator_common.h
...distributed/ps/service/communicator/communicator_common.h
+124
-0
paddle/fluid/distributed/ps/service/coordinator_client.cc
paddle/fluid/distributed/ps/service/coordinator_client.cc
+205
-0
paddle/fluid/distributed/ps/service/coordinator_client.h
paddle/fluid/distributed/ps/service/coordinator_client.h
+256
-0
paddle/fluid/distributed/ps/service/env.cc
paddle/fluid/distributed/ps/service/env.cc
+19
-0
paddle/fluid/distributed/ps/service/env.h
paddle/fluid/distributed/ps/service/env.h
+321
-0
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
+702
-0
paddle/fluid/distributed/ps/service/graph_brpc_client.h
paddle/fluid/distributed/ps/service/graph_brpc_client.h
+139
-0
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
+692
-0
paddle/fluid/distributed/ps/service/graph_brpc_server.h
paddle/fluid/distributed/ps/service/graph_brpc_server.h
+170
-0
paddle/fluid/distributed/ps/service/heter_client.cc
paddle/fluid/distributed/ps/service/heter_client.cc
+428
-0
paddle/fluid/distributed/ps/service/heter_client.h
paddle/fluid/distributed/ps/service/heter_client.h
+257
-0
paddle/fluid/distributed/ps/service/heter_server.cc
paddle/fluid/distributed/ps/service/heter_server.cc
+262
-0
paddle/fluid/distributed/ps/service/heter_server.h
paddle/fluid/distributed/ps/service/heter_server.h
+685
-0
paddle/fluid/distributed/ps/service/ps_client.cc
paddle/fluid/distributed/ps/service/ps_client.cc
+93
-0
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+376
-0
paddle/fluid/distributed/ps/service/ps_local_client.cc
paddle/fluid/distributed/ps/service/ps_local_client.cc
+328
-0
paddle/fluid/distributed/ps/service/ps_local_client.h
paddle/fluid/distributed/ps/service/ps_local_client.h
+237
-0
paddle/fluid/distributed/ps/service/ps_local_server.h
paddle/fluid/distributed/ps/service/ps_local_server.h
+44
-0
No files found.
Too many changes to show.
To preserve performance only
272 of 272+
files are displayed.
Plain diff
Email patch
paddle/fluid/distributed/ps/service/communicator/communicator.cc
0 → 100644
View file @
de2e6515
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include <google/protobuf/text_format.h>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
namespace
paddle
{
namespace
distributed
{
using
framework
::
LoDTensor
;
using
phi
::
SelectedRows
;
const
uint32_t
MAX_FEASIGN_NUM
=
1024
*
100
*
100
;
// Returns the current wall-clock time in microseconds since the Unix epoch
// (1e6 * seconds + microseconds). Used for coarse elapsed-time measurement;
// note it is based on gettimeofday and therefore NOT monotonic — it can jump
// if the system clock is adjusted.
inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, nullptr);  // nullptr instead of NULL (modern C++ idiom)
  return 1e+6 * time.tv_sec + time.tv_usec;
}
// Default constructor: members rely on their in-class initializers; the real
// setup happens later via InitBrpcClient / InitParams and friends.
Communicator::Communicator() {}
// Parses a whitespace-separated gflags string and applies it to the process.
//
// @param gflags  flag string such as "-max_body_size=... -bthread_concurrency=..."
//                If it yields no flags, a default set of brpc tuning flags is
//                used instead.
void Communicator::InitGFlag(const std::string &gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.empty()) {
    // No caller-supplied flags: fall back to brpc defaults.
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-bthread_concurrency=40");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  // ParseCommandLineFlags expects argv[0] to be the program name.
  flags.insert(flags.begin(), "exe default");
  // Use std::vector instead of the original variable-length array
  // (char *flags_ptr[flags.size()]): VLAs are not standard C++ and can
  // overflow the stack for very long flag strings.
  std::vector<char *> flags_ptr(flags.size());
  for (size_t i = 0; i < flags.size(); ++i) {
    // gflags' API takes char**, so cast away constness of c_str().
    flags_ptr[i] = const_cast<char *>(flags[i].c_str());
  }
  int params_cnt = static_cast<int>(flags.size());
  char **params_ptr = flags_ptr.data();
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
// Definitions of Communicator's static members: init_flag_ guards one-time
// initialization (intended for use with std::call_once), and communicator_
// holds the process-wide singleton instance (empty until explicitly set).
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
// Adopts the PS worker client owned by the global FleetWrapper, but only if
// this communicator does not already hold one. The dist_desc and
// host_sign_list parameters are currently unused in this implementation.
void Communicator::InitBrpcClient(
    const std::string &dist_desc,
    const std::vector<std::string> &host_sign_list) {
  auto fleet_instance = paddle::distributed::FleetWrapper::GetInstance();
  if (!_worker_ptr) {
    _worker_ptr = fleet_instance->worker_ptr_;
  }
}
// Fetches the client endpoint signatures from the PS environment, logging
// each one at verbosity 2 before returning them to the caller.
std::vector<uint64_t> Communicator::GetClientInfo() {
  auto client_info = _ps_env.GetClientInfo();
  for (const auto &sig : client_info) {
    VLOG(2) << "Communicator::GetClientInfo " << sig;
  }
  return client_info;
}
// Registers every client listed in host_sign_list with the PS environment.
// Returns whatever SetPsClients reports.
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
  const int client_num = static_cast<int>(host_sign_list.size());
  return _ps_env.SetPsClients(host_sign_list.data(), client_num);
}
// Pulls the dense parameters of `table_id` from the parameter servers into
// the tensors named in `varnames` inside `scope`.
//
// For GPU-resident tensors the pull goes through a CPU staging tensor in
// xpu_temp_scope_ (servers write CPU memory), then the result is copied back
// to the device; CPU tensors are pulled in place.
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
                                int table_id,
                                Scope *scope) {  // pserver_scope_
  platform::RecordEvent record_event("Communicator->RpcRecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  std::vector<paddle::distributed::Region> regions;
  regions.reserve(varnames.size());
  // Build one Region per variable pointing at the buffer the servers should
  // fill (staging buffer for GPU tensors, the tensor itself for CPU).
  for (auto &t : varnames) {
    Variable *var = scope->Var(t);
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      Variable *temp_var = xpu_temp_scope_->Var(t);
      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    } else {
      float *w = tensor->mutable_data<float>(tensor->place());
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  // Blocking pull: wait until the servers have filled every region.
  auto status = _worker_ptr->PullDense(regions.data(), regions.size(),
                                       table_id);
  status.wait();
  // Second pass: log what arrived and, for GPU tensors, copy the staged CPU
  // data back onto the device.
  for (auto &t : varnames) {
    Variable *var = scope->FindVar(t);
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    VLOG(3) << "Communicator::RecvNoBarrier Var " << t << " On gpu? "
            << platform::is_gpu_place(tensor->place());
    float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
    VLOG(3) << "Communicator::RpcRecvDense Var " << t << " table_id "
            << table_id << " Temp_data[0] " << temp_recv_data[0]
            << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      LoDTensor *temp_tensor =
          xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
      framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    }
  }
  return;
}
// Pushes the dense parameters named in `varnames` from `scope` to the
// parameter-server table `table_id` (used by trainer 0 to initialize server
// state — see InitParams).
//
// GPU tensors are first copied into a CPU staging tensor in xpu_temp_scope_
// because the push RPC reads host memory; CPU tensors are sent directly.
//
// Fix: the CPU-branch log message printed " talbe_id " — corrected to
// " table_id " to match every other log line in this file.
void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
                                     int table_id,
                                     const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendDenseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  auto place = platform::CPUPlace();
  std::vector<paddle::distributed::Region> regions;
  for (auto &t : varnames) {
    Variable *var = scope.FindVar(t);
    CHECK(var != nullptr) << "var[" << t << "] not found";
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      // Stage the device tensor into host memory before sending.
      Variable *temp_var = xpu_temp_scope_->Var(t);
      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id
              << " Temp_data[0] " << temp_data[0] << " Temp_data[-1] "
              << temp_data[tensor->numel() - 1];
#endif
    } else {
      float *w = tensor->mutable_data<float>(place);
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(reg);
      VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id
              << " Temp_data[0] " << w[0] << " Temp_data[-1] "
              << w[tensor->numel() - 1];
    }
  }
  // Blocking push of all regions to the target table.
  auto status =
      _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
  status.wait();
  VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
  return;
}
// Pushes the merged dense gradients described by `ctx` (read from `scope`,
// typically delta_scope_) to the parameter servers as one contiguous buffer.
//
// The push is asynchronous at the protocol level: a DownpourBrpcClosure
// tracks per-server responses and decrements _async_call_num when all
// responses have been checked. The closure is deleted by the RPC framework,
// not here.
void Communicator::RpcSendDense(const CommContext &ctx,
                                const Scope &scope) {  // delta_scope_
  platform::RecordEvent record_event("Communicator->RpcSendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = ctx.origin_varnames;
  auto &table_id = ctx.table_id;
  // Shared buffer keeps the gradient data alive for the duration of the call.
  auto dense_data = std::make_shared<std::vector<float>>();
  size_t request_call_num = _worker_ptr->GetServerNums();
  uint32_t num_per_shard =
      DenseDimPerShard(ctx.height_sections[0], request_call_num);
  dense_data->resize(num_per_shard *
                     request_call_num);  // accessor->update_dim() = 1
  float *data = dense_data->data();
  uint32_t pos = 0;
  // Concatenate every variable's gradient into the flat buffer.
  for (size_t i = 0; i < var_names.size(); ++i) {
    // NOTE: this copies the LoDTensor object returned by Get<LoDTensor>().
    const LoDTensor tensor = scope.FindVar(var_names[i])->Get<LoDTensor>();
    size_t count = static_cast<size_t>(tensor.numel());
    const float *g = tensor.data<float>();
    CHECK(pos + count <= dense_data->size())
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << count << "] size[" << dense_data->size() << "]";
    memcpy(data + pos, g, count * sizeof(float));
    pos += count;
  }
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        // ret becomes -1 if any server's response failed the check.
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushDenseRawGradient(table_id, data,
                                                  dense_data->size(), closure);
  status.wait();
  return;
}
// Pushes a full sparse parameter table to the servers: every row of the
// dense LoDTensor `varname` (shape [rows, dim]) is sent, keyed 0..rows-1.
// Used for initializing server-side sparse tables from trainer state.
void Communicator::RpcSendSparseParam(const std::string &varname,
                                      int table_id,
                                      const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<float *> push_g_vec;
  auto *send_var = scope.FindVar(varname);
  auto *tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Keys are simply the row indices 0..sparse_num-1.
  std::vector<uint64_t> sparse_push_keys(sparse_num);
  std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
  push_g_vec.reserve(sparse_num);
  // One value pointer per row, stepping through the contiguous tensor.
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  // Closure checks every server's response; deleted by the RPC framework.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushSparseParam(table_id,
                                             sparse_push_keys.data(),
                                             (const float **)push_g_vec.data(),
                                             sparse_push_keys.size(),
                                             closure);
  status.wait();
  return;
}
// Pushes sparse gradients stored in a SelectedRows variable `var_name` to
// table `table_id`. Row ids become the push keys and each row of the value
// tensor becomes the corresponding gradient vector.
//
// Like RpcSendDense, tracks the in-flight request via _async_call_num and a
// DownpourBrpcClosure that validates each server's response.
void Communicator::RpcSendSparse(const std::string &var_name,
                                 int table_id,
                                 const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<uint64_t> sparse_push_keys;
  std::vector<float *> push_g_vec;
  auto *send_var = scope.FindVar(var_name);
  auto *tensor = send_var->GetMutable<phi::SelectedRows>();
  auto dim = tensor->value().dims()[1];
  // Convert the (int64_t) row ids into uint64_t push keys.
  std::transform(tensor->rows().begin(),
                 tensor->rows().end(),
                 std::back_inserter(sparse_push_keys),
                 [&](int64_t id) { return static_cast<uint64_t>(id); });
  // One gradient pointer per key, stepping through the value tensor row-wise.
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
  }
  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
  // if padding_idx == padding in datareader, the server will core.
  /*
  for (size_t i = 0; i < tensor->rows().size(); ++i) {
    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
    if (real_id != 0) {
      sparse_push_keys.push_back(real_id);
      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
    }
  }
  */
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status =
      _worker_ptr->PushSparseRawGradient(table_id,
                                         sparse_push_keys.data(),
                                         (const float **)push_g_vec.data(),
                                         sparse_push_keys.size(),
                                         closure);
  status.wait();
  return;
}
// Pulls a full sparse table from the servers into the LoDTensor `varname`
// inside `scope`: keys 0..rows-1 are requested and each returned embedding
// is written directly into the corresponding row of the tensor.
void Communicator::RpcRecvSparse(const std::string &varname,
                                 int table_id,
                                 Scope *scope) {
  platform::RecordEvent record_event("Communicator->RpcRecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  auto *send_var = scope->Var(varname);
  auto *tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Keys are the row indices 0..sparse_num-1.
  std::vector<uint64_t> sparse_pull_keys(sparse_num);
  std::iota(sparse_pull_keys.begin(), sparse_pull_keys.end(), 0);
  // One destination pointer per key, stepping through the tensor row-wise.
  std::vector<float *> pull_g_vec;
  for (auto i = 0; i < static_cast<int>(sparse_pull_keys.size()); ++i) {
    pull_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  bool training = true;
  auto status =
      _worker_ptr->PullSparseParam(static_cast<float **>(pull_g_vec.data()),
                                   table_id,
                                   sparse_pull_keys.data(),
                                   sparse_pull_keys.size(),
                                   training);
  status.wait();
  return;
}
// Initializes server-side dense tables: only trainer 0 pushes its local
// parameter values for every registered recv context; other trainers do
// nothing (they will later pull what trainer 0 uploaded).
void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  if (trainer_id_ != 0) {
    return;
  }
  for (auto &ctx_pair : recv_varname_to_ctx) {
    auto &tid = ctx_pair.first;
    auto &names = ctx_pair.second;
    RpcSendDenseParam(names, tid, *recv_scope_);
    VLOG(1) << "push dense param to table " << tid
            << " from 0' trainer done";
  }
}
// Pulls the latest dense parameters for every registered table into the
// trainer-side recv scope.
void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
  for (auto &ctx_pair : recv_varname_to_ctx) {
    auto &tid = ctx_pair.first;
    auto &names = ctx_pair.second;
    RpcRecvDense(names, tid, recv_scope_);
    VLOG(1) << "pull dense param to table " << tid
            << " from 0' trainer done";
  }
}
// Mirrors the local profiler state onto the parameter servers. Only
// trainer 0 drives the server-side profiler on/off switch; the RPCs are
// issued exactly once per transition thanks to do_server_profiler_.
void Communicator::RpcProfilerControl() {
  if (trainer_id_ != 0) {
    return;
  }
  if (!do_server_profiler_ && platform::IsProfileEnabled()) {
    // Profiling just turned on locally: tell the servers to start too.
    do_server_profiler_ = true;
    auto start_status = _worker_ptr->StartProfiler();
    start_status.wait();
  } else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
    // Profiling just turned off locally: tell the servers to stop.
    auto stop_status = _worker_ptr->StopProfiler();
    stop_status.wait();
    do_server_profiler_ = false;
  }
}
// Sends the number of merged batches (`batches`) as the global step counter
// to the tensor table identified by ctx.table_id. A no-op when batches == 0.
//
// The value is staged in an int64 tensor named STEP_COUNTER inside
// `send_scope`, then pushed with a response-checking closure (deleted by the
// RPC framework).
void Communicator::SendGlobalStep(const CommContext &ctx,
                                  int batches,
                                  Scope *send_scope) {
  if (batches == 0) {
    return;
  }
  platform::RecordEvent record_event("Communicator->SendGlobalStep",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = ctx.table_id;
  size_t request_call_num = _worker_ptr->GetServerNums();
  auto &var_name = STEP_COUNTER;
  // Stage the step count into a 1-element int64 CPU tensor.
  auto *out_var = send_scope->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);
  VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
  status.wait();
  return;
}
// Body of the independent receive thread (only runs when independent_recv_
// is enabled). Polls the gradient counter: once more than
// min_send_grad_num_before_recv_ gradients have been pushed, pulls fresh
// parameters and resets the counter; otherwise sleeps briefly and re-checks.
void AsyncCommunicator::RecvThread() {
  if (!independent_recv_) return;
  VLOG(3) << "Independent RecvThread Start and Wait";
  while (running_) {
    int grad_num = grad_num_.load();
    if (grad_num > min_send_grad_num_before_recv_) {
      RecvByCommunicator();
      grad_num_.store(0);
    } else {
      // Nothing to do yet; back off to avoid busy-waiting.
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
  VLOG(1) << "communicator stopped, independent recv thread exit";
}
// Performs one parameter pull (RecvNoBarrier); does nothing once the
// communicator has been stopped.
void AsyncCommunicator::RecvByCommunicator() {
  if (!running_) {
    return;
  }
  RecvNoBarrier();
  VLOG(3) << "run recv graph end";
}
// Pulls dense parameters for every registered recv context into recv_scope_,
// then, for GPU-resident tensors, copies the CPU-staged data (placed in
// xpu_temp_scope_ by RpcRecvDense) back onto the device.
void AsyncCommunicator::RecvNoBarrier() {
  // Phase 1: pull every table's variables from the servers.
  for (auto &iter : recv_varname_to_ctx_) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
  }
  // Phase 2: move staged CPU data back to the GPU where needed.
  for (auto &iter : recv_varname_to_ctx_) {
    auto var_names = iter.second;
    for (auto &t : var_names) {
      Variable *var = recv_scope_->FindVar(t);
      LoDTensor *tensor = var->GetMutable<LoDTensor>();
      VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
              << platform::is_gpu_place(tensor->place());
      if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
        LoDTensor *temp_tensor =
            xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
        framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
#endif
      }
    }
  }
  return;
}
// One round of the async send pipeline. For every send context, a task is
// queued on the thread pool that:
//   1. drains up to max_merge_var_num_ queued gradient snapshots per
//      variable (waiting up to send_wait_times_ * 10ms for new entries),
//   2. merges them into send_scope_ (int64 merge for STEP_COUNTER, float
//      merge otherwise),
//   3. dispatches: global-step push for tensor tables, sparse push for
//      sparse contexts, dense push (plus an inline pull when there is no
//      independent recv thread) otherwise,
//   4. bumps grad_num_ so the independent recv thread knows work happened.
// Blocks until all per-context tasks finish.
void AsyncCommunicator::SendByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    auto send_recv_task = [this, &ctx] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      // All of a context's variables are enqueued together, so the first
      // variable's queue size tells us whether a full snapshot is ready.
      auto &check_queue = send_varname_to_queue_[varnames[0]];
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      int merged_var_num = 0;
      int wait_times = 0;
      // Collect up to max_merge_var_num_ snapshots, polling the queue.
      while (merged_var_num < max_merge_var_num_) {
        if (check_queue->Size() == 0) {
          VLOG(4) << "wait_times -> " << wait_times;
          if (wait_times >= send_wait_times_) {
            break;
          }
          std::this_thread::sleep_for(std::chrono::milliseconds(10));
          wait_times++;
          continue;
        } else {
          wait_times = 0;
          // Pop one snapshot of every variable in this context.
          for (size_t i = 0; i < var_nums; i++) {
            auto &var_name = varnames[i];
            auto &var_queue = send_varname_to_queue_[var_name];
            vars[i].push_back(var_queue->Pop());
          }
          merged_var_num++;
        }
      }
      if (merged_var_num == 0) return;
      // Merge the collected snapshots into send_scope_.
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        if (var_name == STEP_COUNTER) {
          MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
        } else {
          MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
        }
      }
      // Dispatch by context kind.
      if (ctx.is_tensor_table) {
        SendGlobalStep(ctx, merged_var_num, send_scope_.get());
      } else if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
        // Without an independent recv thread, pull fresh params inline.
        if (!independent_recv_ &&
            recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) {
          auto recv_varnames = recv_varname_to_ctx_.at(table_id);
          RpcRecvDense(recv_varnames, table_id, recv_scope_);
        }
      }
      if (independent_recv_) {
        grad_num_.fetch_add(1, std::memory_order_relaxed);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Called after a dense push; bumps the pending-gradient counter so the
// independent recv thread knows new work arrived. No-op otherwise.
void AsyncCommunicator::PushDensePostProcessing() {
  if (!independent_recv_) {
    return;
  }
  grad_num_.fetch_add(1, std::memory_order_relaxed);
}
// Body of the main send thread: idles until the first Send() clears
// waiting_ (or shutdown), then loops merging+sending gradients and
// servicing profiler-control RPCs until running_ is cleared.
void AsyncCommunicator::MainThread() {
  VLOG(3) << "AsyncCommunicator MainThread start and wait";
  const auto probe_interval = std::chrono::milliseconds(100);
  while (running_ && waiting_) {
    std::this_thread::sleep_for(probe_interval);
    VLOG(3) << "wait for running";
  }
  while (running_) {
    SendByCommunicator();
    RpcProfilerControl();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Synchronously pulls sparse embeddings for every id in `inputs` from
// `table_id` directly into the matching `outputs` tensors.
//
// Ids equal to `padding_id` are not sent to the server; their output slots
// are filled with zeros instead. Output tensors are consumed sequentially:
// when one fills up (output_len reaches its numel), the next one is taken.
// On RPC failure, logs the error and sleeps before returning (no retry).
void AsyncCommunicator::PullSparseToTensorSync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    bool is_training,
    std::vector<const LoDTensor *> *inputs,
    std::vector<LoDTensor *> *outputs) {
  std::vector<uint64_t> fea_keys;
  std::vector<float *> pull_result_ptr;
  fea_keys.reserve(MAX_FEASIGN_NUM / 100);
  pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
  // Zero vector used to fill output slots of padded ids.
  std::vector<float> init_value(fea_dim, 0);
  framework::LoDTensor *output = nullptr;
  float *output_data = nullptr;
  // Intentionally wraps: the first ++output_index yields 0.
  size_t output_index = -1;
  size_t output_len = 0;
  for (size_t index = 0; index < inputs->size(); ++index) {
    const framework::LoDTensor *tensor = inputs->at(index);
    const int64_t *ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
      // Advance to the next output tensor when the current one is full.
      if (!output || output_len == size_t(output->numel())) {
        ++output_index;
        CHECK(output_index < outputs->size());  // NOLINT
        output = outputs->at(output_index);
        output->set_lod(tensor->lod());
        output_data = output->mutable_data<float>(place);
        output_len = 0;
        CHECK(output->numel() % fea_dim == 0);  // NOLINT
        CHECK(output_data != nullptr);          // NOLINT
      }
      uint64_t real_id = static_cast<uint64_t>(ids[i]);
      if (real_id == padding_id) {
        // Padding id: zero-fill the slot, skip the server round trip.
        memcpy(output_data + output_len,
               init_value.data(),
               sizeof(float) * fea_dim);
        continue;
      }
      fea_keys.push_back(real_id);
      pull_result_ptr.push_back(output_data + output_len);
    }
  }
  auto status = _worker_ptr->PullSparse(pull_result_ptr.data(),
                                        table_id,
                                        fea_keys.data(),
                                        fea_keys.size(),
                                        is_training);
  status.wait();
  auto ret = status.get();
  if (ret != 0) {
    LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
    sleep(sleep_seconds_before_fail_exit_);
  }
}
// Pushes sparse gradients for the ids in `inputs` (with gradient values in
// `outputs`, shows/clks statistics tensors currently unused) to `table_id`.
//
// Per-value layout is fea_dim + 1 floats: [slot, grad...] — slot is
// hard-coded to 2 below; show/clk population is still TODO (see comments).
// Ids equal to `padding_id` are skipped. When all inputs share a batch
// size, gradients (except the first two CVM columns) are pre-scaled by the
// batch size to undo the mean-reduction of cvm_grad.
void AsyncCommunicator::PushSparseFromTensorAsync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    std::vector<const framework::LoDTensor *> *inputs,
    const framework::LoDTensor *shows,
    const framework::LoDTensor *clks,
    std::vector<framework::LoDTensor *> *outputs) {
  // Determine a common batch size; LoD tensors use sequence count, dense
  // tensors use dims()[0].
  int batch_size = -1;
  bool batch_size_consist = true;
  for (auto *input : *inputs) {
    int cur_batch_size =
        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
    if (batch_size == -1) {
      batch_size = cur_batch_size;
    } else if (batch_size != cur_batch_size) {
      // CHECK(batch_size == cur_batch_size);  // NOLINT
      batch_size_consist = false;
      break;
    }
  }
  CHECK(batch_size > 0);  // NOLINT
  int show_size =
      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
  CHECK(show_size == batch_size || show_size == 1);
  int clk_size =
      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
  CHECK(clk_size == batch_size || clk_size == 1);
  CHECK(outputs->size() == inputs->size());
  std::vector<uint64_t> push_keys;
  push_keys.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<std::vector<float>> push_values;
  push_values.reserve(MAX_FEASIGN_NUM / 100);
  size_t output_len = 0;
  size_t input_idx = 0;
  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim << " batch_size: " << batch_size
          << " batch_size_consist: " << batch_size_consist;
  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
  // const long int* show_tensor = shows->data<int64_t>();
  // const long int* clk_tensor = clks->data<int64_t>();
  for (size_t index = 0; index < inputs->size(); ++index) {
    framework::LoDTensor *g_tensor = outputs->at(index);
    float *g = g_tensor->data<float>();
    if (batch_size_consist) {
      // TODO(zhaocaibei123): add config
      // scale_sparse_gradient_with_batch_size_
      // Scale all but the first two (CVM) columns by the batch size.
      Eigen::Map<
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          g_mat(g, g_tensor->numel() / fea_dim, fea_dim);
      g_mat.rightCols(fea_dim - 2) *=
          batch_size;  // hard code here, because of cvm_grad op
    }
    const framework::LoDTensor *tensor = inputs->at(index);
    const int64_t *ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    output_len = 0;
    if (tensor->lod().size() > 0) {
      // LoD input: walk ids sequence by sequence.
      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
        for (size_t j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
             ++j, output_len += fea_dim) {
          uint64_t real_id = static_cast<uint64_t>(ids[j]);
          if (real_id == padding_id) {
            continue;
          }
          push_keys.emplace_back(real_id);
          push_values.emplace_back(fea_dim + 1);
          // slot show clk grad... consistent with CtrCommonPushValue defined
          // in ctr_accessor.h
          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
          // push_values.back()[1] =
          //     (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
          // push_values.back()[2] =
          //     (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
          float *data = push_values.back().data() + 1;  // hard code here
          memcpy(data, g + output_len, sizeof(float) * fea_dim);
          ++input_idx;
        }
      }
    } else {
      // Dense input: one id per row.
      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
        uint64_t real_id = static_cast<uint64_t>(ids[i]);
        if (real_id == padding_id) {
          continue;
        }
        push_keys.emplace_back(real_id);
        push_values.emplace_back(fea_dim + 1);
        // slot show clk grad... consistent with CtrCommonPushValue defined in
        // ctr_accessor.h
        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
        // push_values.back()[1] =
        //     (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
        // push_values.back()[2] =
        //     (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
        float *data = push_values.back().data() + 1;
        memcpy(data, g + output_len, sizeof(float) * fea_dim);
        ++input_idx;
      }
    }
    CHECK(static_cast<int64_t>(output_len) == g_tensor->numel());
  }
  std::vector<float *> push_g_vec(input_idx, nullptr);
  for (auto i = 0u; i < push_keys.size(); ++i) {
    push_g_vec[i] = push_values.at(i).data();
  }
  PADDLE_ENFORCE_EQ(
      this->Check(table_id),
      true,
      platform::errors::InvalidArgument(
          "can not find table: %s, please check your config", table_id));
  // Fire-and-forget push; the returned future is not waited on here.
  auto status = _worker_ptr->PushSparse(table_id,
                                        push_keys.data(),
                                        (const float **)push_g_vec.data(),
                                        push_keys.size());
}
// Body of the half-async main thread: idles until the first Send() clears
// waiting_ (or shutdown), then loops one synchronized round at a time —
// send merged grads, barrier, pull fresh params, barrier, wake trainers.
void HalfAsyncCommunicator::MainThread() {
  VLOG(3) << "HalfAsyncCommunicator MainThread start and wait";
  const auto probe_interval = std::chrono::milliseconds(100);
  while (running_ && waiting_) {
    std::this_thread::sleep_for(probe_interval);
    VLOG(3) << "wait for running";
  }
  while (running_) {
    SendByCommunicator();
    BarrierSend();
    RecvByCommunicator();
    BarrierRecv();
    BarrierWeakUp();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Wires up the communicator: stores the send/recv contexts and scope,
// creates the internal send/staging scopes, one blocking queue per send
// variable, and the send thread pool.
//
// NOTE: std::move on the const-reference parameters has no effect — the
// assignments below are copies.
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RecvCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);
  send_scope_.reset(new Scope());
  xpu_temp_scope_.reset(new Scope());
  // One bounded queue per variable that will be sent.
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    auto &varnames = ctx.origin_varnames;
    for (auto &var_name : varnames) {
      send_varname_to_queue_[var_name] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
// Destructor: signal both worker threads to stop, then wait for each one
// that was actually started.
AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) {
    main_thread_->join();
  }
  if (recv_thread_) {
    recv_thread_->join();
  }
}
// Starts the communicator: spins up the main send thread and, when
// independent_recv_ is set, a separate recv thread. Does nothing (beyond
// logging) if the global singleton was never initialized.
void AsyncCommunicator::Start() {
  VLOG(1) << "Communicator start";
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    VLOG(1) << "start send thread and recv thread";
    // MainThread idles while waiting_ is true; first Send() clears it.
    waiting_ = true;
    running_ = true;
    // flushing_ = false;
    BarrierTriggerReset(max_merge_var_num_);
    // start send and recv thread
    main_thread_.reset(
        new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
    if (independent_recv_) {
      recv_thread_.reset(
          new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
    }
  }
}
// Stops the communicator: clears running_ so both threads exit their loops,
// then joins and releases the recv thread followed by the main thread.
// Safe to call when the singleton was never initialized.
void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop begin";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    // _worker_ptr->FinalizeWorker();
    VLOG(1) << "client finalize_worker done";
    if (recv_thread_) {
      VLOG(1) << "stop recv thread";
      recv_thread_->join();
      recv_thread_.reset(nullptr);
    }
    if (main_thread_) {
      VLOG(1) << "stop main thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}
// Checks whether the (single) table name is registered for sending.
// As a side effect, when the name is the step counter a scalar "1" LoDTensor
// is pushed into its queue so the counter advances one step per call.
//
// @param var_tables  must contain exactly one table name (enforced)
// @return true iff the name has a registered send context
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
  PADDLE_ENFORCE_EQ(
      var_tables.size(),
      1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
  auto table_name = var_tables[0];
  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
    return false;
  }
  if (table_name == STEP_COUNTER) {
    VLOG(3) << "send step_counter into queue";
    // Build a 1-element int64 tensor holding the value 1 and enqueue it.
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(phi::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    send_varname_to_queue_[table_name]->Push(tmp_var);
  }
  return true;
}
// Returns true iff some registered send context targets `table_id`.
bool AsyncCommunicator::Check(const int table_id) {
  for (const auto &kv : send_varname_to_ctx_) {
    if (kv.second.table_id == table_id) {
      return true;
    }
  }
  return false;
}
// Entry point called by trainer ops: deep-copies each named variable out of
// the caller's scope and pushes the copy into that variable's queue, where
// the send thread later merges and transmits it. Copying decouples the
// caller's lifetime from the asynchronous send.
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const framework::Scope &scope) {
  // First real Send() releases MainThread from its startup wait loop.
  waiting_ = false;
  for (size_t i = 0; i < var_names.size(); i++) {
    auto *var = scope.FindVar(var_names[i]);
    auto tmp_grad_var = std::make_shared<Variable>();
    framework::CopyVariable(*var, tmp_grad_var.get());
    // Push blocks when the bounded queue is full (backpressure).
    send_varname_to_queue_[var_names[i]]->Push(tmp_grad_var);
  }
}
// Drains every pending gradient queue so no stale data survives a reset.
void HalfAsyncCommunicator::Clean() {
  for (auto &entry : send_varname_to_queue_) {
    const auto &name = entry.first;
    auto &queue = entry.second;
    // Pop() would block on an empty queue, so guard with Size() first.
    while (queue->Size() > 0) {
      queue->Pop();
    }
    VLOG(3) << "clean var: " << name << " done";
  }
}
// Atomically decrements the barrier trigger (one fewer batch required
// before the merge fires) and logs the new value.
void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  --barrier_trigger_;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}
// Resets the barrier trigger to `initial_val` (the number of batches that
// must accumulate before a merge/send round starts).
void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_ = initial_val;
  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}
// Trainer-side barrier: registers one arrival, then blocks until the send
// thread finishes the round and BarrierWeakUp() zeroes the counter.
// Returns immediately when the communicator is shutting down so callers
// are never stranded.
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;
  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }
  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    // Woken by BarrierWeakUp(): counter reset to 0 + notify_all.
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}
// Polls until the number of arrived batches reaches the (non-zero) trigger,
// or the communicator stops; returns how many batches are ready to merge.
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    const auto trigger = barrier_trigger_.load();
    if (trigger != 0 && barrier_counter_.load() >= trigger) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  return barrier_counter_.load();
}
// One half-async send round: waits for `batches` gradients to accumulate,
// then for every send context pops `batches` copies of each origin variable
// from its queue, merges them into send_scope_, and pushes the merged result
// to the parameter server (sparse or dense). Contexts are processed in
// parallel on the send thread pool; the function blocks until all finish.
void HalfAsyncCommunicator::SendByCommunicator() {
  int batches = BatchesCounter();
  VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches;
  if (batches <= 0) return;
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    // Capturing &ctx is safe: the map is not mutated and every task is
    // waited on before this function returns.
    auto send_recv_task = [this, &ctx, batches] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        auto &var_queue = send_varname_to_queue_[var_name];
        // Pop exactly `batches` pending gradients for this variable.
        for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop());
        // merge_add=1: sum (not average) the popped gradients.
        MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
      }
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Releases every trainer blocked in Barrier(): zero the arrival counter
// (the condition they wait on) and wake them all.
// NOTE: name is a historic typo for "WakeUp"; kept for ABI/caller stability.
void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}
// Send-side global barrier (barrier type 0) across trainers; skipped once
// the communicator is shutting down.
void SyncCommunicator::BarrierSend() {
  if (!running_) {
    return;
  }
  BarrierWithTable(0);
  VLOG(4) << "BarrierSend with SyncCommunicator";
}
// Recv-side global barrier (barrier type 1) across trainers; skipped once
// the communicator is shutting down.
void SyncCommunicator::BarrierRecv() {
  if (!running_) {
    return;
  }
  BarrierWithTable(1);
  VLOG(4) << "BarrierRecv with SyncCommunicator";
}
// Geo-mode Send: the trainer does not ship gradient values here — it only
// records WHICH sparse ids were touched. The ids from the SelectedRows grad
// are bucketed by pserver shard (id % number of splits) and each bucket is
// pushed into the per-shard channel consumed by MainThread.
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const framework::Scope &scope) {
  // last op in program
  platform::RecordEvent record_event(
      "GeoCommunicator->Send", platform::TracerEventType::Communication, 1);
  waiting_ = false;
  auto before_send = GetCurrentUS();
  // Only var_names[0] is used; geo mode sends one sparse grad per call.
  auto table_name = var_names[0];
  size_t splited_var_nums =
      send_varname_to_ctx_[table_name].splited_varnames.size();
  // shard var name -> set of unique touched ids for that shard.
  std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
  for (size_t j = 0; j < splited_var_nums; j++) {
    ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
        send_varname_to_ctx_[table_name].splited_varnames[j],
        std::unordered_set<int64_t>()));
  }
  auto *var = scope.FindVar(table_name);
  PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(),
                    true,
                    platform::errors::InvalidArgument(
                        "Only need to send Sparse Grad in Geo mode."));
  auto &rows = var->Get<phi::SelectedRows>().rows();
  // insert ids which has not been record
  // VLOG(0) << "fl-ps > table_name: " << table_name << " splited_var_nums: "
  // << splited_var_nums << " rows size: " << rows.size();
  for (size_t j = 0; j < rows.size(); j++) {
    // batch_size == rows.size()
    // Shard assignment: id modulo the number of splits.
    auto ep_idx = rows[j] % splited_var_nums;
    ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
        .insert(rows[j]);
    // VLOG(0) << " id: " << rows[j] << " ";
  }
  // Publish each shard's de-duplicated id list to its channel.
  for (auto &iter : ids_table) {
    auto &key = iter.first;
    auto &sparse_ids_set = iter.second;
    auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
    sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
    sparse_id_queues_.at(key)->Put(sparse_ids_vec);
    VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
            << "'s queue";
  }
  auto after_send = GetCurrentUS();
  VLOG(2) << "run send op finish. use time " << (after_send - before_send);
}
// Geo-mode initialization: copies the RPC contexts, counts parallel tasks
// (one per dense context, one per sparse shard), creates a channel per
// sparse shard, drops contexts whose sparse ids live remotely, and builds
// the thread pool plus the three auxiliary scopes (delta/old/pserver).
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RecvCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  // dense_map - key: table_id, value: params
  recv_scope_ = std::move(recv_scope);
  // Iterator-style loop because sparse contexts with remote ids are erased
  // in place (erase() returns the next valid iterator).
  for (auto it = send_varname_to_ctx_.begin();
       it != send_varname_to_ctx_.end();) {
    auto &ctx = it->second;
    if (!ctx.is_sparse) {
      parallel_task_nums_ += 1;
      it++;
      continue;
    }
    auto &varnames = ctx.origin_varnames;
    if (varnames.empty()) {
      VLOG(0) << "ERROR! sparse variables num can not be zero";
    }
    // e.g. embedding_0.w_0@GRAD (NOTE(review): not used below; kept as-is).
    auto &varname = varnames[0];
    auto &ids = ctx.remote_sparse_ids;
    if (!ids.empty()) {
      // Remote sparse ids: this context is handled elsewhere, drop it.
      it = send_varname_to_ctx_.erase(it);
      continue;
    } else {
      it++;
    }
    for (auto &splited_var : ctx.splited_varnames) {
      // e.g. embedding_0.w_0.block0 — one channel + one task per shard.
      parallel_task_nums_ += 1;
      sparse_id_queues_.insert(
          std::pair<std::string,
                    paddle::framework::Channel<
                        std::shared_ptr<std::vector<int64_t>>>>(
              splited_var,
              paddle::framework::MakeChannel<
                  std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
    }
  }
  send_threadpool_ = std::make_unique<ThreadPool>(thread_pool_size_);
  // delta: computed diffs; old: last-synced snapshot; pserver: pulled params.
  delta_scope_ = std::make_shared<Scope>();
  old_scope_ = std::make_shared<Scope>();
  pserver_scope_ = std::make_shared<Scope>();
  return;
}
// Initial parameter synchronization: dense tables are initialized in
// parallel on the thread pool (all waited before moving on), then sparse
// tables are initialized sequentially. The parameter name is derived by
// stripping the "@GRAD" suffix (5 chars) from the gradient name.
void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());
  for (auto &iter : recv_varname_to_ctx_) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    // &table_id/&varnames reference map nodes; valid until tasks are waited.
    auto recv_task = [this, &table_id, &varnames] {
      InitDense(varnames, table_id);
    };
    if (send_threadpool_ == nullptr) {
      VLOG(0) << "ERROR! send_threadpool_ is nullptr";
    }
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    if (!ctx.is_sparse) {
      continue;
    }
    auto &varname = ctx.origin_varnames[0];
    auto &table_id = ctx.table_id;
    // Strip the trailing "@GRAD" (5 characters) to get the parameter name.
    auto param = varname.substr(0, varname.size() - 5);
    VLOG(0) << "InitSparse: " << param << ", " << table_id;
    InitSparse(param, table_id);
  }
  return;
}
// Dense-table bootstrap: trainer 0 publishes its parameters and everyone
// else pulls them after the barrier, so all trainers start identical.
// Afterwards each parameter is snapshotted into old_scope_ (delta baseline)
// and pserver_scope_ (mirror of the server copy).
void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
                                int table_id) {
  VLOG(1) << "init dense table " << table_id << " begin";
  if (trainer_id_ == 0) {
    RpcSendDenseParam(varnames, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push dense param to table " << table_id
            << " from 0' trainer done";
  } else {
    // Wait until trainer 0 has pushed before pulling.
    BarrierWithTable(1);
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param from table " << table_id
            << " from 0' trainer done";
  }
  // copy to old_scope
  for (auto &t : varnames) {
    auto *global_var = recv_scope_->FindVar(t);
    global_var->GetMutable<framework::LoDTensor>();
    auto *old_var = old_scope_->Var(t);
    old_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, old_var);  // src, dst
    // init pserver_scope_
    auto *pserver_var = pserver_scope_->Var(t);
    pserver_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, pserver_var);
  }
  VLOG(1) << "init dense table " << table_id << " done";
}
// Dense geo update push. For each parameter:
//   delta = (latest - old_snapshot) / trainers_   (averaged local progress)
//   old_snapshot += delta                          (advance the baseline)
// then the deltas in delta_scope_ are sent to the pserver.
void GeoCommunicator::SendDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->SendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = send_ctx.origin_varnames;
  auto &table_id = send_ctx.table_id;
  for (auto &varname : var_names) {
    auto param_name = GradToParam(varname);
    auto *var_latest = recv_scope_->FindVar(param_name);
    auto *var_timestamp = old_scope_->FindVar(param_name);
    PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
                      true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(),
                      true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    auto &t_latest = var_latest->Get<framework::LoDTensor>();
    auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
    phi::CPUContext cpu_ctx;
    // Delta tensor lives in delta_scope_ under the gradient name.
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    // delta = latest - old
    blas.VSUB(t_latest.numel(),
              t_latest.data<float>(),
              t_timestamp->data<float>(),
              t_delta->data<float>());
    // Scale by 1/trainers_ so the server accumulates an average.
    float coefficient = 1.0 / static_cast<float>(trainers_);
    blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
    // old += delta: baseline advances by exactly what we report.
    blas.VADD(t_latest.numel(),
              t_timestamp->data<float>(),
              t_delta->data<float>(),
              t_timestamp->data<float>());
  }
  RpcSendDense(send_ctx, *delta_scope_);
  VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
  return;
}
// Dense geo update pull. Fetches the server copy into pserver_scope_, then
// per parameter:
//   delta  = pserver - old     (progress made by other trainers)
//   latest = latest + delta    (apply it locally)
//   old    = pserver           (new baseline)
void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = send_ctx.table_id;
  auto &varnames = recv_varname_to_ctx_.at(table_id);
  // 1. recv from pserver
  RpcRecvDense(varnames, table_id, pserver_scope_.get());
  // 2.1 pserver - old => delta; 2.2 latest + delta => latest 2.3 old =>
  // pserver
  phi::CPUContext cpu_ctx;
  for (auto &varname : varnames) {
    auto *var_latest = recv_scope_->FindVar(varname);
    auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
    auto *var_old = old_scope_->FindVar(varname);
    auto t_old = var_old->GetMutable<framework::LoDTensor>();
    auto *var_pserver = pserver_scope_->FindVar(varname);
    auto t_pserver = var_pserver->Get<framework::LoDTensor>();
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    blas.VSUB(t_latest->numel(),
              t_pserver.data<float>(),
              t_old->data<float>(),
              t_delta->data<float>());
    blas.VADD(t_latest->numel(),
              t_latest->data<float>(),
              t_delta->data<float>(),
              t_latest->data<float>());
    blas.VCOPY(t_latest->numel(),
               t_pserver.data<float>(),
               t_old->data<float>());
  }
  VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
  return;
}
// Sparse-table bootstrap mirroring InitDense: trainer 0 publishes the
// parameter, the rest pull it after the barrier, and the synchronized value
// is snapshotted into old_scope_ as the baseline for future deltas.
void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id
          << " begin.";
  const bool is_publisher = (trainer_id_ == 0);
  if (is_publisher) {
    RpcSendSparseParam(var_name, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push sparse param to table " << table_id
            << " from 0' trainer done";
  } else {
    BarrierWithTable(1);
    RpcRecvSparse(var_name, table_id, recv_scope_);
    VLOG(1) << "pull sparse param to table " << table_id
            << " from 0' trainer done";
  }
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id
          << " done.";
  // Snapshot the synchronized parameter into old_scope_ (src, dst).
  auto *latest_var = recv_scope_->FindVar(var_name);
  auto *baseline_var = old_scope_->Var(var_name);
  framework::CopyVariable(*latest_var, baseline_var);
  return;
}
// Drains up to max_merge_var_num_ id batches from the shard's channel and
// returns their union (de-duplicated). If the channel stays empty for
// send_wait_times_ consecutive 10ms polls, gives up and returns what it has.
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
    const std::string &send_varname) {
  platform::RecordEvent record_event("GeoCommunicator->MergeSparseIds",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t merge_num = 0, wait_times = 0;
  std::unordered_set<int64_t> sparse_ids;
  while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
    // -> geo_step: 100
    VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
    if (sparse_id_queues_.at(send_varname)->Size() > 0) {
      wait_times = 0;
      std::shared_ptr<std::vector<int64_t>> pop_ids = nullptr;
      sparse_id_queues_.at(send_varname)->Get(pop_ids);
      for (size_t j = 0; j < pop_ids->size(); j++) {
        sparse_ids.insert(pop_ids->at(j));
      }
      merge_num += 1;
      VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
    } else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    }
  }
  std::vector<int64_t> res;
  res.assign(sparse_ids.begin(), sparse_ids.end());
  return res;
}
// Sparse geo update push for one pserver shard. For each touched row id:
//   delta_row = (latest_row - old_row) / trainers_
//   old_row  += delta_row
// The per-row deltas are assembled into a SelectedRows in delta_scope_ and
// pushed to the shard via an async brpc call (tracked by _async_call_num);
// the call is awaited before returning.
void GeoCommunicator::SendSparse(const std::string &varname,
                                 std::vector<int64_t> &sparse_ids,
                                 int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->SendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  if (sparse_ids.size() == 0) {
    return;
  }
  std::string param_name = SplitedGradToParam(varname);
  VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
          << ", ep_idx: " << ep_idx;
  auto *var_latest = recv_scope_->FindVar(param_name);
  auto *var_old = old_scope_->FindVar(param_name);
  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
                    true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  PADDLE_ENFORCE_EQ(var_old->IsInitialized(),
                    true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();
  // dims1 = embedding width (row length).
  auto dims1 = t_latest.dims()[1];
  phi::CPUContext cpu_ctx;
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<phi::SelectedRows>();
  auto *var_t_value = t_delta->mutable_value();
  var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1});
  auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace());
  t_delta->set_rows(sparse_ids);
  t_delta->set_height(t_latest.dims()[0]);
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);
  std::vector<float *> push_g_vec;
  for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
    // delta_row = latest_row - old_row
    blas.VSUB(dims1,
              t_latest.data<float>() + sparse_ids[j] * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1);
    // delta_row /= trainers_
    blas.SCAL(dims1, coefficient, t_value + j * dims1);
    // old_row += delta_row
    blas.VADD(dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1);
    push_g_vec.push_back(t_value + j * dims1);
    VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
            << sparse_ids[j] << " value[0] " << push_g_vec[j][0]
            << " value[-1] " << push_g_vec[j][dims1 - 1];
  }
  // Track the in-flight call; decremented in the closure below.
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      1, [this](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) {
          ret = -1;
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushSparseRawGradientPartial(
      table_id,
      (const uint64_t *)sparse_ids.data(),
      (const float **)push_g_vec.data(),
      sparse_ids.size(),
      closure,
      ep_idx);
  status.wait();
  VLOG(1) << "Finish Send Sparse " << varname
          << ", ids.size = " << sparse_ids.size()
          << ", table_id: " << table_id;
  return;
}
// Sparse geo update pull for one pserver shard: fetches the changed rows
// (keys + values) via PullGeoParam, then per row applies
//   delta  = pserver - old; latest += delta; old = pserver
// using BLAS on row slices of the local embedding tensor.
void GeoCommunicator::RecvSparse(const std::string &varname,
                                 int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  // 1. recv from pserver
  std::vector<uint64_t> keys;
  std::vector<float> values;
  auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
  status.wait();
  std::string param = SplitedGradToParam(varname);
  VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", "
          << table_id << "; ids Size: " << keys.size()
          << "; values size: " << values.size();
  auto *var_latest = recv_scope_->FindVar(param);
  auto *var_old = old_scope_->FindVar(param);
  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();
  // dims1 = row width; values is laid out as keys.size() rows of dims1.
  auto dims1 = t_latest->dims()[1];
  auto numel = keys.size() * dims1;
  std::vector<float> v_delta;
  v_delta.resize(numel);
  phi::CPUContext cpu_ctx;
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
    VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
            << "value[0] " << values[j * dims1] << " value[-1] "
            << values[j * dims1 + dims1 - 1];
    float *latest_data = t_latest->data<float>() + keys[j] * dims1;
    float *old_data = t_old->data<float>() + keys[j] * dims1;
    // pserver - old => delta
    blas.VSUB(dims1,
              values.data() + j * dims1,
              old_data,
              v_delta.data() + j * dims1);
    // latest + delta => latest
    blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
    // pserver => old
    blas.VCOPY(dims1, values.data() + j * dims1, old_data);
  }
  VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id;
}
// Geo-mode worker loop. Idles until the first Send() clears waiting_, then
// repeatedly fans out one send+recv round per context: sparse contexts get
// one task per pserver shard (merge ids -> SendSparse -> RecvSparse), dense
// contexts one task (SendDense -> RecvDense). All tasks of a round are
// awaited before the next round starts.
void GeoCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";
  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  while (running_) {
    std::vector<std::future<void>> tasks;
    tasks.reserve(parallel_task_nums_);
    for (auto &iter : send_varname_to_ctx_) {
      auto &ctx = iter.second;
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        int pserver_num = static_cast<int>(ctx.epmap.size());
        for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
          // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0
          // table_id/ep_idx captured by value (loop-local); &ctx is safe:
          // the map is stable and tasks are waited within this iteration.
          auto send_recv_task = [this, table_id, ep_idx, &ctx] {
            auto splited_varname = ctx.splited_varnames[ep_idx];
            // embedding_0.w_0.block0
            // embedding_1.w_0.block0
            auto sparse_ids = MergeSparseIds(splited_varname);
            SendSparse(splited_varname, sparse_ids, table_id, ep_idx);
            RecvSparse(splited_varname, table_id, ep_idx);
          };
          tasks.emplace_back(
              send_threadpool_->enqueue(std::move(send_recv_task)));
        }
      } else {
        auto send_recv_task = [this, &ctx] {
          SendDense(ctx);
          RecvDense(ctx);
        };
        tasks.emplace_back(
            send_threadpool_->enqueue(std::move(send_recv_task)));
      }
    }
    for (auto &task : tasks) {
      task.wait();
    }
  }
}
// FL-mode client setup: reuses the FleetWrapper's worker client if none is
// set, lazily creates the coordinator client, and points it at the pserver
// list. `dist_desc` is currently unused here.
void FLCommunicator::InitBrpcClient(
    const std::string &dist_desc,
    const std::vector<std::string> &host_sign_list) {
  auto fleet = paddle::distributed::FleetWrapper::GetInstance();
  if (_worker_ptr.get() == nullptr) {
    VLOG(0) << "fl-ps > FLCommunicator::InitBrpcClient get _worker_ptr";
    _worker_ptr = fleet->worker_ptr_;
    // FleetWrapper::InitWorker must be excuted
    // before, but no need for Coordinator
  }
  if (coordinator_client_ptr_ == nullptr) {
    coordinator_client_ptr_.reset(new CoordinatorClient);
  }
  // NOTE(review): size() narrows to int16_t here — fine for realistic
  // server counts, but would wrap past 32767 servers; confirm SetPsServers'
  // expected type.
  int16_t servers = host_sign_list.size();
  coordinator_client_ptr_->_env = &ps_env_;
  coordinator_client_ptr_->_env->SetPsServers(&host_sign_list, servers);
}
void
FLCommunicator
::
StartCoordinatorClient
(
const
std
::
vector
<
std
::
string
>
&
trainer_endpoints
)
{
if
(
coordinator_client_ptr_
==
nullptr
)
{
LOG
(
ERROR
)
<<
"coordinator_client_ptr_ is null"
;
return
;
}
coordinator_client_ptr_
->
Initialize
(
trainer_endpoints
);
VLOG
(
0
)
<<
"fl-ps > StartCoordinatorClient finish!"
;
}
void
FLCommunicator
::
StartCoordinatorServer
()
{
if
(
coordinator_client_ptr_
==
nullptr
)
{
LOG
(
ERROR
)
<<
"coordinator_client_ptr_ is null"
;
}
int
ret
=
coordinator_client_ptr_
->
StartClientService
();
if
(
ret
!=
0
)
{
LOG
(
ERROR
)
<<
"coordinator_client_ptr_ StartClientService failed"
;
}
VLOG
(
0
)
<<
"fl-ps > StartCoordinatorServer finished!"
;
return
;
}
// Forwards to the coordinator client, which holds the map of registered
// FL client ids to their reported info strings.
std::unordered_map<uint32_t, std::string> FLCommunicator::QueryFLClientsInfo() {
  auto clients_info = coordinator_client_ptr_->QueryFLClientsInfo();
  return clients_info;
}
// Hands the per-client FL strategy map to the coordinator client to store.
void FLCommunicator::SaveFLStrategy(
    const std::unordered_map<uint32_t, std::string> &fl_strategy) {
  coordinator_client_ptr_->SaveFLStrategy(fl_strategy);
}
// Background loop: keeps distributing FL strategies while the communicator
// is running; each iteration blocks inside RpcSendFLStrategy until a
// strategy is ready.
void FLCommunicator::SendThreadAsync() {
  while (is_running_) {
    RpcSendFLStrategy();
  }
}
void
FLCommunicator
::
RpcSendFLStrategy
()
{
std
::
set
<
uint32_t
>
clients
=
coordinator_client_ptr_
->
GetFLClientIds
();
coordinator_client_ptr_
->
WaitForFLStrategyReady
();
for
(
auto
client_id
:
clients
)
{
coordinator_client_ptr_
->
SendFLStrategy
(
client_id
);
}
coordinator_client_ptr_
->
ResetFLStrategyFlag
();
VLOG
(
0
)
<<
"fl-ps > RpcSendFLStrategy finished!"
;
return
;
}
// Brings up the whole coordinator: endpoint, client, server, and finally
// the async strategy-distribution thread. Order matters — the send thread
// uses the client/server started just before it.
void FLCommunicator::StartCoordinator(
    const std::string &self_endpoint,
    const std::vector<std::string> &trainer_endpoints) {
  coordinator_client_ptr_->SetEndpoint(self_endpoint);
  StartCoordinatorClient(trainer_endpoints);
  StartCoordinatorServer();
  async_send_thread_.reset(
      new std::thread(&FLCommunicator::SendThreadAsync, this));
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/communicator/communicator.h
0 → 100644
View file @
de2e6515
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
distributed
{
class
PSClient
;
struct
CommContext
;
}
// namespace distributed
}
// namespace paddle
DECLARE_bool
(
communicator_is_sgd_optimizer
);
namespace
paddle
{
namespace
distributed
{
using
Scope
=
framework
::
Scope
;
using
Variable
=
framework
::
Variable
;
// Bounded blocking FIFO used to hand gradient variables from trainer ops to
// the communicator's send thread. Hand-rolled two-condition-variable design:
// writers wait on full_cond_ while the queue is at capacity, readers wait on
// empty_cond_ while it is empty; waiter counters let Notify() skip signaling
// when nobody is waiting. All public operations lock mutex_.
template <typename T>
class BlockingQueue {
 public:
  // capacity must be > 0 (enforced); the queue blocks writers at capacity.
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(
        capacity_,
        0,
        platform::errors::InvalidArgument(
            "The capacity must be greater than 0."));
  }

  // Copy-push; blocks while full. Always returns true.
  bool Push(const T &elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.push_back(elem);
    Notify();
    return true;
  }

  // Blocks the caller (holding `lock`) until there is room to write.
  // Wakes one reader first if any are waiting, so the queue can drain.
  bool WaitForWrite(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (FullUnlocked()) {
      if (empty_waiters_ != 0) {
        empty_cond_.notify_one();
      }
      full_waiters_++;
      full_cond_.wait(lock);
      full_waiters_--;
    }
    return true;
  }

  // Blocks the caller (holding `lock`) until there is an element to read.
  // Wakes one writer first if any are waiting, so the queue can refill.
  bool WaitForRead(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (EmptyUnlocked()) {
      if (full_waiters_ != 0) {
        full_cond_.notify_one();
      }
      empty_waiters_++;
      empty_cond_.wait(lock);
      empty_waiters_--;
    }
    return true;
  }

  // Predicates below assume mutex_ is already held by the caller.
  bool EmptyUnlocked() { return queue_.empty(); }

  bool FullUnlocked() { return queue_.size() >= capacity_; }

  // Wakes one waiting reader/writer when the queue state allows progress.
  void Notify() {
    if (empty_waiters_ != 0 && (!EmptyUnlocked())) {
      empty_cond_.notify_one();
    }
    if (full_waiters_ != 0 && (!FullUnlocked())) {
      full_cond_.notify_one();
    }
  }

  // Move-push; blocks while full. Always returns true.
  bool Push(T &&elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.emplace_back(std::move(elem));
    Notify();
    return true;
  }

  // Removes and returns the front element; blocks while empty.
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForRead(lock);
    T rc(std::move(queue_.front()));
    queue_.pop_front();
    Notify();
    return rc;
  }

  // Fixed capacity supplied at construction.
  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  // Current element count (a snapshot; may change immediately after).
  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  int empty_waiters_ = 0;   // readers blocked on empty_cond_
  int full_waiters_ = 0;    // writers blocked on full_cond_
  std::condition_variable empty_cond_;
  std::condition_variable full_cond_;
  const size_t capacity_;
  std::deque<T> queue_;
  mutable std::mutex mutex_;  // mutable so const Cap()/Size() can lock
};
// Shorthand for a flattened (1-D) Eigen view over a framework tensor,
// defaulting to row-major layout with Eigen's standard index type.
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Merges the homogeneous variables in `vars` into scope->Var(var_name).
// LoDTensor inputs (all dims must match) are element-wise summed, and
// divided by vars.size() when merge_add is false (i.e. averaged).
// SelectedRows inputs are combined with MergeAdd / MergeAverage.
// Throws for any other variable type.
//
// @param var_name   name of the output variable created in `scope`
// @param vars       non-empty list of inputs, all the same type
// @param scope      destination scope for the merged result
// @param merge_add  true = sum, false = average
template <typename T>
inline void MergeVars(const std::string &var_name,
                      const std::vector<std::shared_ptr<Variable>> &vars,
                      Scope *scope,
                      bool merge_add = true) {
  PADDLE_ENFORCE_NE(
      vars.empty(),
      true,
      platform::errors::InvalidArgument("vector vars are empty."));
  auto cpu_place = platform::CPUPlace();
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);
  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
            << "; merge add: " << merge_add;
    // init output tensor
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<T>(dims, cpu_place);
    // check the input dims
    for (auto &var : vars) {
      auto &var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(
          var_t.dims(),
          dims,
          platform::errors::InvalidArgument("vars should have the same dims."));
    }
    // set output tensor to 0.
    phi::CPUContext cpu_ctx;
    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
    // sum all vars to out
    auto result = EigenVector<T>::Flatten(*out_t);
    for (auto &var : vars) {
      auto &in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<T>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
    if (!merge_add) {
      // Average instead of sum.
      result.device(*cpu_ctx.eigen_device()) =
          result / static_cast<T>(vars.size());
    }
  } else if (var0->IsType<phi::SelectedRows>()) {
    auto &slr0 = var0->Get<phi::SelectedRows>();
    auto *out_slr = out_var->GetMutable<phi::SelectedRows>();
    // Reset the output SelectedRows before merging into it.
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
    std::vector<const phi::SelectedRows *> inputs;
    inputs.reserve(vars.size());
    for (auto &var : vars) {
      inputs.push_back(&var->Get<phi::SelectedRows>());
    }
    phi::CPUContext dev_ctx;
    if (merge_add) {
      paddle::operators::math::scatter::MergeAdd<phi::CPUContext, T> merge_add;
      merge_add(dev_ctx, inputs, out_slr);
    } else {
      paddle::operators::math::scatter::MergeAverage<phi::CPUContext, T>
          merge_average;
      merge_average(dev_ctx, inputs, out_slr);
    }
    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
                                                   var0->Type()));
  }
}
// send variable name -> RPC context describing how that variable is sent.
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
// table id -> names of the variables received from that table.
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
// sparse id -> its per-id float values.
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
// Base class of all trainer-side communicators. Owns the brpc worker client
// (`_worker_ptr`) and exposes the dense/sparse send/recv RPC primitives that
// concrete strategies (async / half-async / sync / geo / FL) build on.
// A single process-wide instance is managed through InitInstance/GetInstance.
class Communicator {
 public:
  Communicator();

  // Copies the key/value environment map into `envs` and extracts the
  // barrier table id, trainer id and trainer count when the map is non-empty.
  // Throws std::out_of_range / std::invalid_argument if one of the three
  // expected keys is missing or non-numeric.
  explicit Communicator(const std::map<std::string, std::string> &envs_) {
    VLOG(3) << "Communicator Init Envs";
    for (auto &iter : envs_) {
      envs[iter.first] = iter.second;
      VLOG(3) << iter.first << ": " << iter.second;
    }
    if (!envs.empty()) {
      barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
      trainer_id_ = std::stoi(envs.at("trainer_id"));
      trainers_ = std::stoi(envs.at("trainers"));
    }
  }

  // Creates `_worker_ptr` and connects it to the pservers described by
  // `dist_desc` / `host_sign_list` (defined in communicator.cc).
  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);

  virtual std::vector<uint64_t> GetClientInfo();

  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

  // 1. recv dense param
  virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                            int table_id,
                            Scope *scope);
  // 2. send dense param
  virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
                                 int table_id,
                                 const Scope &scope);
  // 3. send dense grad
  virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
  // 4. send sparse grad
  virtual void RpcSendSparse(const std::string &var_name,
                             int table_id,
                             const Scope &scope);
  // 5. send sparse param
  virtual void RpcSendSparseParam(const std::string &varname,
                                  int table_id,
                                  const Scope &scope);
  // 6. recv sparse param
  virtual void RpcRecvSparse(const std::string &varname,
                             int table_id,
                             Scope *scope);
  // 7. send global step
  virtual void SendGlobalStep(const CommContext &ctx,
                              int batches,
                              Scope *send_scope);

  // FL-coordinator hooks; no-ops here, overridden by FLCommunicator.
  virtual std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    return {};
  }
  virtual void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string> &fl_strategy) {}
  virtual void StartCoordinator(
      const std::string &self_endpoint,
      const std::vector<std::string> &trainer_endpoints) {}

  virtual ~Communicator() {}
  virtual void RpcProfilerControl();

  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

  // note: only for pull dense param first before training
  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

  virtual void Start() = 0;

  virtual void Stop() = 0;

  virtual bool IsRunning() { return running_; }

  virtual void Clean() {}

  virtual bool Check(const int table_id) = 0;
  virtual bool Check(const std::vector<std::string> &var_tables) = 0;

  virtual void Send(const std::vector<std::string> &var_names,
                    const framework::Scope &scope) = 0;

  virtual void RecvNoBarrier() {}

  virtual void Barrier() {}

  // Blocking barrier across trainers implemented via the pserver barrier
  // table; raises InvalidArgument if the RPC reports a non-zero status.
  virtual void BarrierWithTable(uint32_t barrier_type) {
    auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
    rets.wait();
    int status = rets.get();
    PADDLE_ENFORCE_EQ(
        status,
        0,
        platform::errors::InvalidArgument(
            "The ret status must be 0 when barrier with table"));
  }

  // Opens direct client-to-client brpc connections (used by geo/FL modes).
  virtual void CreateC2CConnection(int pserver_timeout_ms,
                                   int pserver_connect_timeout_ms,
                                   int max_retry) {
    _worker_ptr->CreateClient2ClientConnection(
        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  }

  virtual void BarrierTriggerDecrement() {}

  virtual void BarrierTriggerReset(int init_counter) {}

  virtual void InitEnvs() = 0;

  virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                        const RecvCtxMap &recv_varname_to_ctx,
                        Scope *recv_scope) {}

  // Raw pointer accessor to the process-wide singleton (may be nullptr
  // before InitInstance has run).
  static Communicator *GetInstance() { return communicator_.get(); }

  static std::shared_ptr<Communicator> GetInstantcePtr() {
    return communicator_;
  }

  // Thread-safe, idempotent creation of the singleton: the first caller wins
  // via std::call_once; later calls return the already-built instance.
  template <typename T>
  static Communicator *InitInstance(
      const RpcCtxMap &send_ctx,
      const RecvCtxMap &recv_ctx,
      const std::string &dist_desc,
      const std::vector<std::string> &host_sign_list,
      Scope *recv_scope,
      const std::map<std::string, std::string> &envs) {
    std::call_once(init_flag_,
                   &Communicator::InitWithRpcCtx<T>,
                   send_ctx,
                   recv_ctx,
                   dist_desc,
                   host_sign_list,
                   recv_scope,
                   std::ref(envs));
    return communicator_.get();
  }

  // called by InitInstance.
  template <typename T>
  static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
                             const RecvCtxMap &recv_ctx,
                             const std::string &dist_desc,
                             const std::vector<std::string> &host_sign_list,
                             Scope *recv_scope,
                             const std::map<std::string, std::string> &envs) {
    VLOG(0) << "Communicator type is: " << typeid(T).name();
    if (communicator_.get() == nullptr) {
      communicator_.reset(new T(std::ref(envs)));
      communicator_->InitEnvs();
      communicator_->InitBrpcClient(dist_desc, host_sign_list);
      communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
    }
  }

  PSClient *GetPsClient() { return _worker_ptr.get(); }

  RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker

 protected:
  bool running_ = false;
  bool waiting_ = true;
  bool flushing_ = false;
  bool do_server_profiler_ = false;
  static std::shared_ptr<Communicator> communicator_;
  static std::once_flag init_flag_;
  // Raw copy of the environment map handed to the constructor.
  std::unordered_map<std::string, std::string> envs;

  // Compute how much dense storage each shard holds.
  // Note: always adds 1, so the per-shard dim is rounded up even when the
  // total divides evenly — presumably intentional padding; verify before
  // changing.
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  void InitGFlag(const std::string &gflags);
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
  int servers_ = 0;
  int trainers_;
  int trainer_id_ = 0;
  int barrier_table_id_ = 0;
  RpcCtxMap send_varname_to_ctx_;
  RecvCtxMap recv_varname_to_ctx_;
  Scope *recv_scope_;  // should be global scope
  std::unique_ptr<Scope> xpu_temp_scope_;
  std::atomic<uint32_t> _async_call_num{0};
};
class
AsyncCommunicator
:
public
Communicator
{
public:
AsyncCommunicator
()
:
Communicator
()
{}
explicit
AsyncCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs
)
:
Communicator
(
envs
)
{}
~
AsyncCommunicator
();
void
InitEnvs
()
{
independent_recv_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"communicator_independent_recv_thread"
)));
min_send_grad_num_before_recv_
=
std
::
stoi
(
envs
.
at
(
"communicator_min_send_grad_num_before_recv"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
need_global_step_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"need_global_step"
)));
}
void
Start
()
override
;
void
Stop
()
override
;
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RecvCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
override
;
virtual
void
MainThread
();
virtual
void
RecvThread
();
virtual
bool
Check
(
const
int
table_id
);
virtual
bool
Check
(
const
std
::
vector
<
std
::
string
>
&
var_tables
);
void
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
framework
::
Scope
&
scope
)
override
;
virtual
void
SendByCommunicator
();
virtual
void
RecvByCommunicator
();
virtual
void
RecvNoBarrier
();
virtual
int
BatchesCounter
()
{
return
1
;
}
virtual
void
BarrierSend
()
{}
virtual
void
BarrierRecv
()
{}
virtual
void
BarrierWeakUp
()
{}
void
PushDensePostProcessing
();
void
PullSparseToTensorSync
(
const
uint64_t
table_id
,
int
fea_dim
,
uint64_t
padding_id
,
platform
::
Place
place
,
bool
is_training
,
std
::
vector
<
const
framework
::
LoDTensor
*>
*
inputs
,
// NOLINT
std
::
vector
<
framework
::
LoDTensor
*>
*
outputs
);
// NOLINT
void
PushSparseFromTensorAsync
(
const
uint64_t
table_id
,
int
fea_dim
,
uint64_t
padding_id
,
platform
::
Place
place
,
std
::
vector
<
const
framework
::
LoDTensor
*>
*
inputs
,
const
framework
::
LoDTensor
*
shows
,
const
framework
::
LoDTensor
*
clicks
,
std
::
vector
<
framework
::
LoDTensor
*>
*
outputs
);
protected:
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>>
send_varname_to_queue_
;
std
::
unique_ptr
<::
ThreadPool
>
send_threadpool_
{
nullptr
};
int
min_send_grad_num_before_recv_
;
int
thread_pool_size_
;
int
max_merge_var_num_
;
int
send_wait_times_
;
int
send_queue_size_
;
bool
need_global_step_
=
false
;
bool
independent_recv_
=
true
;
int
parallel_task_nums_
=
0
;
int32_t
sleep_seconds_before_fail_exit_
;
std
::
unique_ptr
<
std
::
thread
>
main_thread_
{
nullptr
};
std
::
unique_ptr
<
std
::
thread
>
recv_thread_
{
nullptr
};
std
::
unique_ptr
<
Scope
>
send_scope_
;
// an independent scope
std
::
atomic_uint
grad_num_
{
0
};
// the num of gradient sent since last recv
};
class
HalfAsyncCommunicator
:
public
AsyncCommunicator
{
public:
HalfAsyncCommunicator
()
{}
explicit
HalfAsyncCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs
)
:
AsyncCommunicator
(
envs
)
{}
void
InitEnvs
()
{
// enfore to recv after send
independent_recv_
=
false
;
min_send_grad_num_before_recv_
=
0
;
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
need_global_step_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"need_global_step"
)));
VLOG
(
1
)
<<
"HalfAsyncCommunicator Initialized"
;
}
void
MainThread
()
override
;
void
SendByCommunicator
()
override
;
void
Clean
()
override
;
void
Barrier
()
override
;
void
BarrierTriggerDecrement
()
override
;
void
BarrierTriggerReset
(
int
initial_val
)
override
;
int
BatchesCounter
();
void
BarrierWeakUp
();
protected:
// mutex for Wait for barrier
std
::
mutex
barrier_mutex_
;
std
::
condition_variable
barrier_cond_
;
std
::
atomic
<
int64_t
>
barrier_trigger_
{
0
};
std
::
atomic
<
int64_t
>
barrier_counter_
{
0
};
};
class
SyncCommunicator
:
public
HalfAsyncCommunicator
{
public:
SyncCommunicator
()
:
HalfAsyncCommunicator
()
{}
explicit
SyncCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs
)
:
HalfAsyncCommunicator
(
envs
)
{}
void
InitEnvs
()
{
// enfore to recv after send
independent_recv_
=
false
;
min_send_grad_num_before_recv_
=
0
;
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
need_global_step_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"need_global_step"
)));
VLOG
(
1
)
<<
"SyncCommunicator Initialized"
;
}
void
BarrierSend
();
void
BarrierRecv
();
private:
std
::
vector
<
std
::
string
>
pserver_endpoints_
{};
};
class
GeoCommunicator
:
public
AsyncCommunicator
{
public:
GeoCommunicator
()
:
AsyncCommunicator
()
{}
explicit
GeoCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs
)
:
AsyncCommunicator
(
envs
)
{}
void
InitParams
(
const
RecvCtxMap
&
recv_varname_to_ctx
)
override
;
void
InitDense
(
std
::
vector
<
std
::
string
>
&
varnames
,
int
table_id
);
// NOLINT
void
InitSparse
(
const
std
::
string
&
var_name
,
int
table_id
);
void
SendDense
(
const
CommContext
&
send_ctx
);
void
RecvDense
(
const
CommContext
&
send_ctx
);
std
::
vector
<
int64_t
>
MergeSparseIds
(
const
std
::
string
&
varname
);
void
SendSparse
(
const
std
::
string
&
varname
,
std
::
vector
<
int64_t
>
&
sparse_ids
,
// NOLINT
int
table_id
,
int
ep_idx
);
void
RecvSparse
(
const
std
::
string
&
varname
,
int
table_id
,
int
ep_idx
);
void
MainThread
()
override
;
virtual
void
InitEnvs
()
{
independent_recv_
=
false
;
min_send_grad_num_before_recv_
=
0
;
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
// id_queue's size
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_queue_size_
=
max_merge_var_num_
;
VLOG
(
1
)
<<
"GeoCommunicator Initialized"
;
}
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RecvCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
override
;
void
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
framework
::
Scope
&
scope
)
override
;
void
SendByCommunicator
()
{
return
;
}
void
RecvByCommunicator
()
override
{
return
;
}
inline
std
::
string
GradToParam
(
const
std
::
string
var_name
)
{
std
::
string
param_name
=
var_name
.
substr
(
0
,
var_name
.
size
()
-
5
);
return
param_name
;
}
inline
std
::
string
SplitedGradToParam
(
const
std
::
string
delta_name
)
{
// delta_name: emb.delta0
auto
pos
=
delta_name
.
find
(
".block"
);
std
::
string
param_name
=
delta_name
.
substr
(
0
,
pos
);
return
param_name
;
}
public:
// parameter for delta calc and send
std
::
shared_ptr
<
Scope
>
delta_scope_
;
// parameter for storage the pserver param after last recv
std
::
shared_ptr
<
Scope
>
old_scope_
;
// parameter on pserver
std
::
shared_ptr
<
Scope
>
pserver_scope_
;
std
::
unordered_map
<
std
::
string
,
paddle
::
framework
::
Channel
<
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>
sparse_id_queues_
;
};
class
FLCommunicator
:
public
GeoCommunicator
{
public:
FLCommunicator
()
:
GeoCommunicator
()
{}
~
FLCommunicator
()
{
is_running_
=
false
;
async_send_thread_
->
join
();
}
explicit
FLCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs
)
:
GeoCommunicator
(
envs
)
{}
void
InitEnvs
()
override
{}
virtual
void
InitBrpcClient
(
const
std
::
string
&
dist_desc
,
const
std
::
vector
<
std
::
string
>
&
host_sign_list
);
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RecvCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{}
void
StartCoordinatorClient
(
const
std
::
vector
<
std
::
string
>
&
trainer_endpoints
);
void
StartCoordinatorServer
();
void
StartCoordinator
(
const
std
::
string
&
self_endpoint
,
const
std
::
vector
<
std
::
string
>
&
trainer_endpoints
)
override
;
std
::
unordered_map
<
uint32_t
,
std
::
string
>
QueryFLClientsInfo
();
void
SaveFLStrategy
(
const
std
::
unordered_map
<
uint32_t
,
std
::
string
>
&
fl_strategy
);
void
SendThreadAsync
();
void
RpcSendFLStrategy
();
private:
int
thread_pool_size_
=
1
;
bool
is_running_
=
true
;
PaddlePSEnvironment
ps_env_
;
std
::
shared_ptr
<
CoordinatorClient
>
coordinator_client_ptr_
{
nullptr
};
std
::
unique_ptr
<
std
::
thread
>
async_send_thread_
{
nullptr
};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/communicator/communicator_common.h
0 → 100644
View file @
de2e6515
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
namespace
paddle
{
namespace
distributed
{
// Per-variable communication context: how one (possibly sliced) variable is
// sent/received between trainer and parameter servers.
// Fixes vs. original: the hand-written copy constructor (a plain memberwise
// copy) is replaced by `= default` (Rule of Zero — same semantics, less code
// to keep in sync when members change), and the scalar members get default
// member initializers matching the main constructor's defaults so a
// default-constructed CommContext no longer holds indeterminate values.
struct CommContext {
  CommContext() = default;

  CommContext(const std::string &name,
              const std::vector<std::string> &names,
              const std::vector<std::string> &emap,
              const std::vector<int64_t> &sections,
              const std::vector<std::string> &origin_names,
              int trainer_id,
              bool merge_add = true,
              bool is_sparse = true,
              bool is_distributed = false,
              int table_id = -1,
              bool is_tensor_table = false,
              bool is_datanorm_table = false,
              int64_t program_id = -1,
              const std::vector<int32_t> &remote_sparse_ids = {})
      : var_name(name),
        splited_varnames(names),
        epmap(emap),
        height_sections(sections),
        origin_varnames(origin_names),
        trainer_id(trainer_id),
        merge_add(merge_add),
        is_sparse(is_sparse),
        is_distributed(is_distributed),
        table_id(table_id),
        program_id(program_id),
        is_tensor_table(is_tensor_table),
        is_datanorm_table(is_datanorm_table),
        remote_sparse_ids(remote_sparse_ids) {}

  // Memberwise copy — identical to the previous hand-written version.
  CommContext(const CommContext &ctx) = default;

  // Human-readable dump of the whole context (for logging/debugging).
  std::string print() const {
    std::stringstream ss;
    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
    ss << " table_id: " << table_id;

    std::for_each(remote_sparse_ids.begin(),
                  remote_sparse_ids.end(),
                  [&](const int &i) { ss << "remote_sparse_id: " << i << " "; });

    // Note: assumes epmap and height_sections are at least as long as
    // splited_varnames — the three are built in parallel by callers.
    for (size_t i = 0; i < splited_varnames.size(); i++) {
      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
         << " section: " << height_sections[i] << " ";
    }
    ss << "origin varnames: ";
    for (size_t i = 0; i < origin_varnames.size(); i++) {
      ss << origin_varnames[i] << " ";
    }
    ss << " aggregation->add: " << merge_add;
    ss << " is_sparse: " << is_sparse;
    ss << " is_distributed: " << is_distributed << "\n";
    ss << " table_id: " << table_id << "\n";
    ss << " program_id: " << program_id << "\n";
    ss << " is_tensor_table: " << is_tensor_table << "\n";
    ss << " is_datanorm_table: " << is_datanorm_table << "\n";
    return ss.str();
  }

  std::string var_name;                        // logical variable name
  std::vector<std::string> splited_varnames;   // per-slice variable names
  std::vector<std::string> epmap;              // endpoint per slice
  std::vector<int64_t> height_sections;        // rows per slice
  std::vector<std::string> origin_varnames;    // pre-split source variables
  int trainer_id = 0;
  bool merge_add = true;          // true: sum gradients; false: average
  bool is_sparse = true;
  bool is_distributed = false;
  int table_id = -1;
  int64_t program_id = -1;
  bool is_tensor_table = false;
  bool is_datanorm_table = false;
  std::vector<int32_t> remote_sparse_ids;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/coordinator_client.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"
static
const
int
MIN_PORT
=
8500
;
static
const
int
MAX_PORT
=
65535
;
namespace
paddle
{
namespace
distributed
{
// Maximum number of federated-learning clients the coordinator supports.
DEFINE_uint64(total_fl_client_size, 100, "supported total fl client size");
// How long (seconds) the coordinator waits for all FL clients to report
// before giving up on a round.
DEFINE_uint32(coordinator_wait_all_clients_max_time, 60, "uint32: s");
// brpc entry point for FLClient -> Coordinator messages: dispatches the
// request to the handler registered for its cmd_id and reports failures
// through the response's error fields.
void CoordinatorService::FLService(
    ::google::protobuf::RpcController *controller,
    const CoordinatorReqMessage *request,
    CoordinatorResMessage *response,
    ::google::protobuf::Closure *done) {
  // RAII: guarantees `done` runs on every return path.
  brpc::ClosureGuard done_guard(done);
  response->set_err_code(0);
  response->set_err_msg("");
  brpc::Controller *cntl = static_cast<brpc::Controller *>(controller);
  const int32_t msg_type = request->cmd_id();
  const uint32_t from_client_id = request->client_id();
  VLOG(0) << "fl-ps > recv from client id: " << from_client_id
          << ", msg_type: " << msg_type;
  // TODO(ziyoujiyi): find is not thread safe, beacuse of RB_Tree traversal
  const auto handler_itr = _service_handle_map.find(msg_type);
  if (handler_itr == _service_handle_map.end()) {
    LOG(ERROR) << "fl-ps > unknown flClient2Coordinator msg type: " << msg_type;
    return;
  }
  const int handler_ret =
      handler_itr->second(*request, response, cntl);  // SaveFLClientInfo
  if (handler_ret != 0) {
    response->set_err_code(-1);
    response->set_err_msg("fl-ps > handle flClient2Coordinator msg failed");
  }
}
// Connects the coordinator to every parameter server and every FL client.
// Returns 0 on success, -1 on any failure (missing env, or a channel that
// cannot be initialized even after retrying with the integral-IP endpoint).
int32_t CoordinatorClient::Initialize(
    const std::vector<std::string>& trainer_endpoints) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = paddle::distributed::FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms =
      paddle::distributed::FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;

  std::string server_ip_port;

  // Fetch the pserver list and connect to each of them.
  if (_env == nullptr) {
    LOG(ERROR) << "_env is null in CoordinatorClient::Initialize()";
    return -1;
  }
  std::vector<PSHost> pserver_list = _env->GetPsServers();

  _pserver_channels.resize(pserver_list.size());
  for (size_t i = 0; i < pserver_list.size(); ++i) {
    server_ip_port.assign(pserver_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(pserver_list[i].port));
    // _pserver_channels[i] is a fixed-size std::array (size 1), so this
    // inner loop runs once per pserver.
    for (size_t j = 0; j < _pserver_channels[i].size(); ++j) {
      _pserver_channels[i][j].reset(new brpc::Channel());
      if (_pserver_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
          0) {
        LOG(ERROR) << "CoordinatorClient connect to PServer:" << server_ip_port
                   << " Failed! Try again.";
        // Retry with the numeric-IP form of the endpoint (hostname may not
        // resolve in some environments).
        std::string int_ip_port =
            GetIntTypeEndpoint(pserver_list[i].ip, pserver_list[i].port);
        if (_pserver_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "CoordinatorClient connect to PServer:" << int_ip_port
                     << " Failed!";
          return -1;
        }
      }
    }
  }

  // Build the FL client host list from the trainer endpoints and connect.
  std::vector<PSHost> fl_client_list;
  fl_client_list.resize(trainer_endpoints.size());
  if (fl_client_list.empty()) {
    LOG(ERROR) << ">>> fl clients addr info lost";
    return -1;
  }
  for (size_t i = 0; i < trainer_endpoints.size(); i++) {
    // NOTE(review): assumes each endpoint is "ip:port"; a malformed entry
    // would make addr[1] out of range / std::stol throw — confirm callers
    // always validate the endpoint strings.
    std::vector<std::string> addr =
        paddle::string::Split(trainer_endpoints[i], ':');
    fl_client_list[i].ip = addr[0];
    fl_client_list[i].port = std::stol(addr[1]);
    fl_client_list[i].rank = i;  // TO CHECK
  }
  std::string fl_client_ip_port;
  for (size_t i = 0; i < fl_client_list.size(); ++i) {
    fl_client_ip_port.assign(fl_client_list[i].ip);
    fl_client_ip_port.append(":");
    fl_client_ip_port.append(std::to_string(fl_client_list[i].port));
    uint32_t rank = fl_client_list[i].rank;
    VLOG(0) << "fl-ps > coordinator connect to fl_client: " << rank;
    _fl_client_channels[rank].reset(new brpc::Channel());
    if (_fl_client_channels[rank]->Init(
            fl_client_ip_port.c_str(), "", &options) != 0) {
      LOG(ERROR) << "CoordinatorClient connect to FLClient:"
                 << fl_client_ip_port << " Failed! Try again.";
      // Same numeric-IP fallback as for pservers above.
      std::string int_ip_port =
          GetIntTypeEndpoint(fl_client_list[i].ip, fl_client_list[i].port);
      if (_fl_client_channels[rank]->Init(int_ip_port.c_str(), "", &options) !=
          0) {
        LOG(ERROR) << "CoordinatorClient connect to PSClient:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
  }
  SetTotalFLClientsNum(fl_client_list.size());
  SetDefaultFLStrategy();
  return 0;
}
int32_t
CoordinatorClient
::
StartClientService
()
{
_service
.
Initialize
();
_server
.
AddService
(
&
_service
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
);
brpc
::
ServerOptions
options
;
options
.
num_threads
=
1
;
if
(
_endpoint
.
empty
())
{
LOG
(
ERROR
)
<<
"fl-ps > coordinator server endpoint not set"
;
return
-
1
;
}
auto
addr
=
paddle
::
string
::
Split
(
_endpoint
,
':'
);
std
::
string
ip
=
addr
[
0
];
std
::
string
port
=
addr
[
1
];
std
::
string
rank
=
addr
[
2
];
std
::
string
ip_port
=
ip
+
":"
+
port
;
if
(
_server
.
Start
(
ip_port
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"fl-ps > StartClientService failed"
;
return
-
1
;
}
uint32_t
port_
=
std
::
stol
(
port
);
int32_t
rank_
=
std
::
stoi
(
rank
);
_env
->
RegisteCoordinatorClient
(
ip
,
port_
,
rank_
);
VLOG
(
0
)
<<
"fl-ps > coordinator service addr: "
<<
ip
<<
", "
<<
port
<<
", "
<<
_coordinator_id
;
return
0
;
}
void
CoordinatorClient
::
SendFLStrategy
(
const
uint32_t
&
client_id
)
{
size_t
request_call_num
=
1
;
FlClientBrpcClosure
*
closure
=
new
FlClientBrpcClosure
(
request_call_num
,
[](
void
*
done
)
{
auto
*
closure
=
reinterpret_cast
<
FlClientBrpcClosure
*>
(
done
);
int
ret
=
0
;
if
(
closure
->
check_response
(
0
,
PUSH_FL_STRATEGY
)
!=
0
)
{
LOG
(
ERROR
)
<<
"fl-ps > SendFLStrategy failed"
;
ret
=
-
1
;
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
std
::
future
<
int32_t
>
fut
=
promise
->
get_future
();
closure
->
add_promise
(
promise
);
closure
->
request
(
0
)
->
set_cmd_id
(
PUSH_FL_STRATEGY
);
closure
->
request
(
0
)
->
set_client_id
(
client_id
);
std
::
string
fl_strategy
=
_fl_strategy_mp
[
client_id
];
closure
->
request
(
0
)
->
set_str_params
(
fl_strategy
);
brpc
::
Channel
*
rpc_channel
=
_fl_client_channels
[
client_id
].
get
();
if
(
rpc_channel
==
nullptr
)
{
LOG
(
ERROR
)
<<
"fl-ps > _fl_client_channels is null"
;
return
;
}
PsService_Stub
rpc_stub
(
rpc_channel
);
// DownpourPsClientService
rpc_stub
.
FLService
(
closure
->
cntl
(
0
),
closure
->
request
(
0
),
closure
->
response
(
0
),
closure
);
fut
.
wait
();
VLOG
(
0
)
<<
"fl-ps > SendFLStrategy to client: "
<<
client_id
<<
" finished"
;
return
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/coordinator_client.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace
paddle
{
namespace
distributed
{
// gflags defined in brpc_ps_client.cc / coordinator_client.cc.
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(pserver_connect_timeout_ms);
DECLARE_uint64(total_fl_client_size);
DECLARE_uint32(coordinator_wait_all_clients_max_time);

// Signature of a handler for one FLClient -> Coordinator message type;
// returns 0 on success, non-zero to make FLService report an error.
using CoordinatorServiceFunc =
    std::function<int32_t(const CoordinatorReqMessage &request,
                          CoordinatorResMessage *response,
                          brpc::Controller *cntl)>;
// Plain value type carrying one FL client's reported training state.
// Fixes vs. original: `client_id` and `iteration_idx` were left
// uninitialized (indeterminate reads possible); they are now zero-initialized.
// The empty user-written constructor/destructor are dropped per the Rule of
// Zero — the class remains default-constructible, copyable and movable.
class ClientReportedInfo {
 public:
  uint32_t client_id = 0;      // reporting client's id
  uint32_t iteration_idx = 0;  // training iteration the report refers to
  double auc = 0.0;            // reported AUC metric
};
// Shared state between the coordinator's brpc service threads (which call
// SaveFLClientInfo) and the strategy thread (which calls QueryFLClientsInfo).
// Guarded by _mtx; round completion is signalled through _cv.
class CoordinatorServiceHandle {
 public:
  CoordinatorServiceHandle() {}

  virtual ~CoordinatorServiceHandle() {}

  // Records one client's reported parameters (or just its heartbeat when the
  // payload is empty) and signals the query thread once every expected client
  // of the round has reported.
  void SaveFLClientInfo(const CoordinatorReqMessage &request) {
    auto client_id = request.client_id();
    const std::string &str_params = request.str_params();
    // each client is allowed to send empty message to maintain heartbeat(i.e.
    // use staleness msg)
    std::unique_lock<std::mutex> lck(_mtx);
    if (str_params.size() != 0) {
      _client_info_mp[client_id] = str_params;
    } else {
      LOG(INFO) << "fl-ps > content in request from " << client_id
                << " is null";
    }
    fl_client_ids.insert(client_id);
    _fl_clients_count++;
    // TODO(ziyoujiyi): how to process when a client loss connection?
    if (_fl_clients_count.load() == last_round_total_fl_clients_num) {
      _is_all_clients_info_collected = true;
      _cv.notify_one();
    }
    lck.unlock();
    VLOG(0) << "last_round_total_fl_clients_num: "
            << last_round_total_fl_clients_num
            << ", has recved fl client num: " << _fl_clients_count.load();
    return;
  }

  // Blocks until all clients of the round have reported or the configured
  // timeout elapses, then returns a copy of the collected info and resets
  // the round state.
  // NOTE(review): the wait predicate `f` itself loops and sleeps while the
  // condition_variable holds _mtx locked, so SaveFLClientInfo is blocked
  // from acquiring the mutex during the whole poll — the _cv.notify_one()
  // above can never be what wakes this thread. This works only because the
  // predicate eventually returns true on its own; confirm before relying on
  // the timing, and consider _cv.wait_for with a plain predicate instead.
  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    platform::Timer timeline;
    double query_wait_time = 0.0;
    timeline.Start();
    auto f = [&]() -> bool {
      while (query_wait_time <
             paddle::distributed::
                 FLAGS_coordinator_wait_all_clients_max_time) {  // in case that
                                                                 // some
                                                                 // clients down
        if (_is_all_clients_info_collected == true) {
          // LOG(INFO) << "fl-ps > _is_all_clients_info_collected";
          return true;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
        timeline.Pause();
        query_wait_time += timeline.ElapsedSec();
      }
      // LOG(WARNNING) << "fl-ps > query_wait_time exceed!";
      return true;
    };
    std::unique_lock<std::mutex> lck(_mtx);
    _cv.wait(lck, f);
    lck.unlock();

    _is_all_clients_info_collected = false;
    _fl_clients_count.store(0);
    return _client_info_mp;
  }

 public:
  std::unordered_map<uint32_t, std::string> _client_info_mp;
  std::set<uint32_t> fl_client_ids;
  uint32_t last_round_total_fl_clients_num = 0;
  bool _is_all_clients_info_collected = false;

 private:
  std::mutex _mtx;
  std::condition_variable _cv;
};
class
CoordinatorService
:
public
PsService
{
public:
CoordinatorService
()
{
_coordinator_service_handle
=
std
::
make_shared
<
CoordinatorServiceHandle
>
();
}
virtual
~
CoordinatorService
()
{}
virtual
void
Initialize
()
{
_service_handle_map
[
PUSH_FL_CLIENT_INFO_SYNC
]
=
std
::
bind
(
&
CoordinatorService
::
SaveFLClientInfo
,
this
,
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
,
std
::
placeholders
::
_3
);
}
virtual
void
FLService
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
CoordinatorReqMessage
*
request
,
CoordinatorResMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
);
int32_t
SaveFLClientInfo
(
const
CoordinatorReqMessage
&
request
,
CoordinatorResMessage
*
response
,
brpc
::
Controller
*
cntl
)
{
_coordinator_service_handle
->
SaveFLClientInfo
(
request
);
return
0
;
}
void
SetTotalFLClientsNum
(
uint32_t
all_fl_clients_num
)
{
if
(
_coordinator_service_handle
.
get
()
!=
nullptr
)
{
_coordinator_service_handle
->
last_round_total_fl_clients_num
=
all_fl_clients_num
;
}
else
{
LOG
(
ERROR
)
<<
"fl-ps > _coordinator_service_handle is null in "
"CoordinatorService"
;
}
return
;
}
std
::
set
<
uint32_t
>
GetFLClientIds
()
{
return
_coordinator_service_handle
->
fl_client_ids
;
}
std
::
unordered_map
<
uint32_t
,
std
::
string
>
QueryFLClientsInfo
()
{
return
_coordinator_service_handle
->
QueryFLClientsInfo
();
}
private:
std
::
shared_ptr
<
CoordinatorServiceHandle
>
_coordinator_service_handle
;
std
::
unordered_map
<
int32_t
,
CoordinatorServiceFunc
>
_service_handle_map
;
std
::
mutex
_mtx
;
};
// Coordinator-side client: connects to all pservers and FL clients, serves
// its own brpc endpoint, and distributes FL strategies to clients.
// Fixes vs. original: `_total_clients_num` was uninitialized (read by
// SetDefaultFLStrategy) and is now zero-initialized; SetEndpoint used
// `std::move` on a const reference — a silent copy — and now takes the
// string by value and moves it (same for existing callers, one copy saved
// for rvalue arguments); the `[=]` capture in WaitForFLStrategyReady is
// replaced by the equivalent, non-deprecated `[this]`.
class CoordinatorClient : public BrpcPsClient {
 public:
  CoordinatorClient() : _coordinator_id(0) {}

  virtual ~CoordinatorClient() {}

  // Connects to pservers and FL clients; 0 on success, -1 on failure.
  int32_t Initialize(const std::vector<std::string>& trainer_endpoints);

  void SetTotalFLClientsNum(uint32_t all_fl_clients_num) {
    _service.SetTotalFLClientsNum(all_fl_clients_num);
    this->_total_clients_num = all_fl_clients_num;
  }

  int32_t StartClientService();

  // Caches the per-client strategies and wakes any thread blocked in
  // WaitForFLStrategyReady.
  void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string>& fl_strategy) {
    for (auto it = fl_strategy.begin(); it != fl_strategy.end(); it++) {
      uint32_t client_id = it->first;
      _fl_strategy_mp[client_id] = it->second;
    }
    std::unique_lock<std::mutex> lck(_mtx);
    _is_fl_strategy_ready = true;
    _cv.notify_all();
  }

  // Blocks until SaveFLStrategy has published a strategy set.
  void WaitForFLStrategyReady() {
    std::unique_lock<std::mutex> lck(_mtx);
    _cv.wait(lck, [this]() { return _is_fl_strategy_ready; });
  }

  void SendFLStrategy(const uint32_t& client_id);

  // NOTE: written without holding _mtx, as in the original — callers are
  // presumed to sequence this with SaveFLStrategy externally.
  void ResetFLStrategyFlag() { _is_fl_strategy_ready = false; }

  // Seeds an empty strategy for every known client id.
  void SetDefaultFLStrategy() {
    for (size_t i = 0; i < _total_clients_num; i++) {
      _fl_strategy_mp[i] = "";
    }
  }

  std::set<uint32_t> GetFLClientIds() { return _service.GetFLClientIds(); }

  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    return _service.QueryFLClientsInfo();
  }

  // Stores the "ip:port:rank" endpoint this coordinator should serve on.
  void SetEndpoint(std::string endpoint) { _endpoint = std::move(endpoint); }

 public:
  size_t _coordinator_id;
  uint32_t _total_clients_num = 0;
  std::string _endpoint;
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
      _pserver_channels;  // coordinator2pserver
  std::unordered_map<uint32_t, std::shared_ptr<brpc::Channel>>
      _fl_client_channels;  // coordinator2psclient
  brpc::Server _server;
  CoordinatorService _service;
  std::unordered_map<uint32_t, std::string> _fl_strategy_mp;
  bool _is_fl_strategy_ready = false;
  std::mutex _mtx;
  std::condition_variable _cv;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/env.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/env.h"
namespace
paddle
{
namespace
distributed
{}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/env.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <arpa/inet.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
namespace
paddle
{
namespace
distributed
{
// One parameter-server/worker host: IPv4 address, port and logical rank.
// Can round-trip through a packed 64-bit label or a "ip:port:rank" string.
struct PSHost {
  std::string ip;
  uint32_t port;
  uint32_t rank;

  PSHost() = default;
  // FIX: take the ip by const reference instead of by const value (the
  // original `const std::string ip` forced an extra copy).
  PSHost(const std::string& ip, uint32_t port, uint32_t rank)
      : ip(ip), port(port), rank(rank) {}

  // |---ip---|---port---|--rank--|
  // |-32bit--|--20bit---|--12bit-|
  // Packs the host into one uint64_t using the layout above. Assumes
  // port < 2^20 and rank < 2^12 — larger values would corrupt neighbors.
  uint64_t SerializeToUint64() {
    uint64_t host_label = 0;
    host_label = inet_addr(ip.c_str());
    host_label = host_label << 32;
    host_label += (port << 12);
    host_label += rank;
    return host_label;
  }

  // Inverse of SerializeToUint64().
  void ParseFromUint64(uint64_t host_label) {
    static uint64_t rank_label_mask = (1L << 12) - 1;
    static uint64_t port_label_mask = (1L << 20) - 1;
    rank = host_label & rank_label_mask;
    port = (host_label >> 12) & port_label_mask;
    uint32_t ip_addr = (host_label >> 32);
    // FIX: named cast instead of C-style `(in_addr*)` cast.
    ip = inet_ntoa(*reinterpret_cast<in_addr*>(&ip_addr));  // NOLINT
  }

  // Human-readable dump for logging.
  std::string ToString() {
    std::stringstream s;
    s << "host: " << ip;
    s << " port: " << port;
    s << " rank: " << rank;
    s << " uint64: " << SerializeToUint64();
    return s.str();
  }

  // for open source parameter server
  // Serializes as "ip:port:rank".
  std::string SerializeToString() {
    std::stringstream s;
    s << ip << ":";
    s << port << ":";
    s << rank;
    return s.str();
  }

  // Parses an "ip:port:rank" endpoint string.
  // FIX: take by const reference (was pass-by-value) and use bounds-checked
  // .at() — the original indexed with [] and had undefined behavior on a
  // malformed endpoint with fewer than three fields; now it throws
  // std::out_of_range instead.
  void ParseFromString(const std::string& endpoint) {
    std::vector<std::string> endpoint_info;
    StringSplit(endpoint, ':', &endpoint_info);
    ip = endpoint_info.at(0);
    port = std::stoi(endpoint_info.at(1));
    rank = std::stoi(endpoint_info.at(2));
  }

  // Splits `str` on `sep` into *pieces. Interior empty segments are kept; a
  // trailing empty segment is dropped. When the whole string is empty,
  // `ignore_null` decides whether one empty piece is emitted.
  void StringSplit(const std::string& str,
                   char sep,
                   std::vector<std::string>* pieces,
                   bool ignore_null = true) {
    pieces->clear();
    if (str.empty()) {
      if (!ignore_null) {
        pieces->push_back(str);
      }
      return;
    }
    size_t pos = 0;
    size_t next = str.find(sep, pos);
    while (next != std::string::npos) {
      pieces->push_back(str.substr(pos, next - pos));
      pos = next + 1;
      next = str.find(sep, pos);
    }
    if (!str.substr(pos).empty()) {
      pieces->push_back(str.substr(pos));
    }
  }
};
class
PSEnvironment
{
public:
explicit
PSEnvironment
()
{}
// NOLINT
virtual
~
PSEnvironment
()
{}
virtual
int32_t
SetPsServers
(
uint64_t
*
host_sign_list
,
int
node_num
)
{
return
0
;
}
virtual
int32_t
SetPsServers
(
const
std
::
vector
<
std
::
string
>
*
host_endpoint_list
,
int
node_num
)
{
return
0
;
}
virtual
int32_t
SetPsClients
(
uint64_t
*
host_sign_list
,
int
node_num
)
{
return
0
;
}
virtual
int32_t
SetPsClients
(
std
::
string
*
host_endpoint_list
,
int
node_num
)
{
return
0
;
}
virtual
uint64_t
GetLocalHostSign
()
{
return
0
;
}
virtual
std
::
vector
<
PSHost
>
GetPsServers
()
const
{
return
_ps_server_list
;
}
virtual
int32_t
RegistePsServer
(
const
std
::
string
&
ip
,
uint32_t
port
,
int32_t
rank
)
{
return
RegistePsHost
(
ip
,
port
,
rank
,
_ps_server_list
,
_ps_server_sign_set
);
}
virtual
std
::
vector
<
PSHost
>
GetPsClients
()
const
{
return
_ps_client_list
;
}
virtual
int32_t
RegistePsClient
(
const
std
::
string
&
ip
,
uint32_t
port
,
int32_t
rank
)
{
return
RegistePsHost
(
ip
,
port
,
rank
,
_ps_client_list
,
_ps_client_sign_set
);
}
virtual
std
::
vector
<
PSHost
>
GetCoordinators
()
const
{
return
_coordinator_list
;
}
virtual
int32_t
RegisteCoordinatorClient
(
const
std
::
string
&
ip
,
uint32_t
port
,
int32_t
rank
)
{
return
RegistePsHost
(
ip
,
port
,
rank
,
_coordinator_list
,
_coordinator_sign_set
);
}
virtual
std
::
vector
<
uint64_t
>
GetClientInfo
()
{
std
::
vector
<
uint64_t
>
client_info
;
for
(
auto
&
i
:
_ps_client_list
)
{
client_info
.
push_back
(
i
.
SerializeToUint64
());
}
return
client_info
;
}
virtual
std
::
vector
<
std
::
string
>
GetClientInfo
(
bool
use_string_endpoint
)
{
if
(
use_string_endpoint
)
{
std
::
vector
<
std
::
string
>
client_info
;
for
(
auto
&
i
:
_ps_client_list
)
{
client_info
.
push_back
(
i
.
SerializeToString
());
}
return
client_info
;
}
return
{};
}
virtual
void
SetTrainers
(
int
trainers
)
{
trainers_
=
trainers
;
}
virtual
int
GetTrainers
()
{
return
trainers_
;
}
protected:
//注册一个host // NOLINT
virtual
int32_t
RegistePsHost
(
const
std
::
string
&
ip
,
uint32_t
port
,
int32_t
rank
,
std
::
vector
<
PSHost
>
&
host_list
,
// NOLINT
std
::
unordered_set
<
uint64_t
>
&
sign_set
)
{
// NOLINT
PSHost
host
;
host
.
ip
=
ip
;
host
.
port
=
port
;
host
.
rank
=
rank
;
if
(
sign_set
.
count
(
rank
)
==
0
)
{
host_list
.
push_back
(
host
);
sign_set
.
insert
(
rank
);
}
return
0
;
}
int
trainers_
=
0
;
std
::
vector
<
PSHost
>
_ps_client_list
;
std
::
unordered_set
<
uint64_t
>
_ps_client_sign_set
;
// for unique filter
std
::
vector
<
PSHost
>
_ps_server_list
;
std
::
unordered_set
<
uint64_t
>
_ps_server_sign_set
;
// for unique filter
std
::
vector
<
PSHost
>
_coordinator_list
;
std
::
unordered_set
<
uint64_t
>
_coordinator_sign_set
;
};
class
PaddlePSEnvironment
:
public
PSEnvironment
{
public:
explicit
PaddlePSEnvironment
()
{}
// NOLINT
virtual
~
PaddlePSEnvironment
()
{}
virtual
int32_t
SetPsServers
(
uint64_t
*
host_sign_list
,
int
node_num
)
{
_ps_server_list
.
clear
();
_ps_server_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
[
i
]
>
0
)
{
PSHost
host
;
host
.
ParseFromUint64
(
host_sign_list
[
i
]);
_ps_server_list
.
push_back
(
host
);
_ps_server_sign_set
.
insert
(
host
.
SerializeToUint64
());
}
}
std
::
sort
(
_ps_server_list
.
begin
(),
_ps_server_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
return
0
;
}
virtual
int32_t
SetPsServers
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
int
node_num
)
{
_ps_server_list
.
clear
();
_ps_server_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
->
at
(
i
)
!=
""
)
{
PSHost
host
;
host
.
ParseFromString
(
host_sign_list
->
at
(
i
));
_ps_server_list
.
push_back
(
host
);
_ps_server_sign_set
.
insert
(
host
.
rank
);
}
}
std
::
sort
(
_ps_server_list
.
begin
(),
_ps_server_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
return
0
;
}
virtual
int32_t
SetPsClients
(
uint64_t
*
host_sign_list
,
int
node_num
)
{
_ps_client_list
.
clear
();
_ps_client_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
[
i
]
>
0
)
{
PSHost
host
;
host
.
ParseFromUint64
(
host_sign_list
[
i
]);
_ps_client_list
.
push_back
(
host
);
_ps_client_sign_set
.
insert
(
host
.
SerializeToUint64
());
}
}
std
::
sort
(
_ps_client_list
.
begin
(),
_ps_client_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
return
0
;
}
virtual
int32_t
SetPsClients
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
int
node_num
)
{
_ps_client_list
.
clear
();
_ps_client_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
->
at
(
i
)
!=
""
)
{
PSHost
host
;
host
.
ParseFromString
(
host_sign_list
->
at
(
i
));
_ps_client_list
.
push_back
(
host
);
_ps_client_sign_set
.
insert
(
host
.
rank
);
}
}
std
::
sort
(
_ps_client_list
.
begin
(),
_ps_client_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
VLOG
(
1
)
<<
"env.set_ps_clients done
\n
"
;
return
0
;
}
virtual
void
SetCoordinators
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
size_t
node_num
)
{
_coordinator_list
.
clear
();
_coordinator_sign_set
.
clear
();
for
(
size_t
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
->
at
(
i
)
!=
""
)
{
PSHost
host
;
host
.
ParseFromString
(
host_sign_list
->
at
(
i
));
_coordinator_list
.
push_back
(
host
);
_coordinator_sign_set
.
insert
(
host
.
rank
);
VLOG
(
0
)
<<
"fl-ps > coordinator info in env: "
<<
host
.
ToString
();
}
}
return
;
}
virtual
uint64_t
GetLocalHostSign
()
{
if
(
_ps_client_list
.
size
()
>
0
)
{
return
_ps_client_list
[
0
].
SerializeToUint64
();
}
else
{
return
0
;
}
}
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
// Dispatches a PS RPC either to the in-process graph service (when the target
// channel is this process's own local channel — avoids a network round trip)
// or to the remote server via the regular brpc stub.
void GraphPsService_Stub::service(
    ::google::protobuf::RpcController* controller,
    const ::paddle::distributed::PsRequestMessage* request,
    ::paddle::distributed::PsResponseMessage* response,
    ::google::protobuf::Closure* done) {
  // FIX: nullptr instead of NULL.
  if (graph_service != nullptr && local_channel == channel()) {
    // Run the handler on the task pool so this call stays non-blocking, like
    // the remote path. NOTE(review): the returned future is discarded — the
    // caller is expected to synchronize via `done`; confirm with task_pool's
    // contract.
    task_pool->enqueue([this, controller, request, response, done]() -> int {
      this->graph_service->service(controller, request, response, done);
      return 0;
    });
  } else {
    PsService_Stub::service(controller, request, response, done);
  }
}
// Maps a node id to the server owning its shard. Shards are assigned to
// servers in contiguous blocks of ceil(shard_num / server_size).
// NOTE(review): a negative id would yield a negative modulo result here —
// presumably node ids are non-negative; confirm with callers.
int GraphBrpcClient::get_server_index_by_id(int64_t id) {
  const int shard_num = get_shard_num();
  // Ceil division, written arithmetically instead of with a conditional.
  const int shards_per_server = (shard_num + server_size - 1) / server_size;
  return id % shard_num / shards_per_server;
}
// Asynchronously fetches the named features for `node_ids` from the graph
// servers that own them, scattering one request per distinct server.
// Results are written as res[feat_idx][query_idx] (feature order follows
// `feature_names`, query order follows `node_ids`).
// NOTE(review): assumes `res` is pre-sized to
// [feature_names.size()][node_ids.size()] by the caller — TODO confirm.
// The closure captures `feature_names` and `res` by reference, so the caller
// must keep them (and itself) alive until the returned future resolves.
// Returns a future resolving to 0 on success, -1 if every request failed.
std::future<int32_t> GraphBrpcClient::get_node_feat(
    const uint32_t& table_id,
    int idx_,
    const std::vector<int64_t>& node_ids,
    const std::vector<std::string>& feature_names,
    std::vector<std::vector<std::string>>& res) {
  // Map each distinct owning server to a dense request slot.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Bucket node ids (and their original positions) per request slot.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // Completion callback: parses each server's response attachment. Wire
  // format per response: repeated (size_t length, raw bytes) records,
  // iterated feature-major then node-major.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
              0) {
            ++fail_num;
          } else {
            // Copy the whole attachment into a flat buffer before parsing.
            auto& res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char* buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void*)(buffer), bytes_size);
            for (size_t feat_idx = 0; feat_idx < feature_names.size();
                 ++feat_idx) {
              for (size_t node_idx = 0;
                   node_idx < query_idx_buckets.at(request_idx).size();
                   ++node_idx) {
                int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
                // Length-prefixed feature payload.
                size_t feat_len = *(size_t*)(buffer);
                buffer += sizeof(size_t);
                auto feature = std::string(buffer, feat_len);
                res[feat_idx][query_idx] = feature;
                buffer += feat_len;
              }
            }
          }
          // Only report failure when *every* request failed.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fill and fire one RPC per bucketed server.
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    // Params: graph index, raw node-id array, tab-joined feature names.
    closure->request(request_idx)->add_params((char*)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char*)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Asynchronously asks every graph server to clear the nodes of the given
// type/index in `table_id`. The returned future resolves to 0 only if all
// servers succeed (-1 as soon as any response check fails).
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
                                                  int type_id,
                                                  int idx_) {
  // One sub-request per server; server_size is captured by value so the
  // callback does not depend on `this` staying valid for the count.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      server_size, [&, server_size = this->server_size](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < server_size;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
            ++fail_num;
            break;  // one failure is enough to fail the whole call
          }
        }
        ret = fail_num == 0 ? 0 : -1;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Broadcast the clear command to every server.
  for (size_t i = 0; i < server_size; i++) {
    int server_index = i;
    closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
    closure->request(server_index)->set_table_id(table_id);
    closure->request(server_index)->set_client_id(_client_id);
    // Params: type id, then graph index.
    closure->request(server_index)->add_params((char*)&type_id, sizeof(int));
    closure->request(server_index)->add_params((char*)&idx_, sizeof(int));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(server_index),
                     closure->request(server_index),
                     closure->response(server_index),
                     closure);
  }
  return fut;
}
// Asynchronously adds graph nodes, routing each id to the server that owns
// its shard. When `is_weighted_list` is non-empty, a parallel bool array is
// sent marking which nodes carry weights (ids past its end default false).
// The returned future resolves to 0 if at least one request succeeds, -1 if
// every request fails.
std::future<int32_t> GraphBrpcClient::add_graph_node(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t>& node_id_list,
    std::vector<bool>& is_weighted_list) {
  std::vector<std::vector<int64_t>> request_bucket;
  std::vector<std::vector<bool>> is_weighted_bucket;
  bool add_weight = is_weighted_list.size() > 0;
  // Dense request slots, one per distinct owning server.
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
    if (add_weight)
      is_weighted_bucket[index_mapping[server_index]].push_back(
          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
                                              : false);
  }
  size_t request_call_num = request_bucket.size();
  // Success means "not all sub-requests failed".
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
              0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    // Params: graph index, raw node-id array, optional weighted-flag array.
    closure->request(request_idx)->add_params((char*)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char*)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    if (add_weight) {
      // FIX: the original used a variable-length array (`bool weighted[n+1]`),
      // which is not standard C++ and can overflow the stack for large
      // batches; use a heap array instead. std::vector<bool> is unsuitable
      // because add_params needs contiguous bool storage. The +1 keeps the
      // original's guard against a zero-length allocation.
      size_t weighted_num = is_weighted_bucket[request_idx].size();
      std::unique_ptr<bool[]> weighted(new bool[weighted_num + 1]);
      for (size_t j = 0; j < weighted_num; j++)
        weighted[j] = is_weighted_bucket[request_idx][j];
      closure->request(request_idx)
          ->add_params((char*)weighted.get(), sizeof(bool) * weighted_num);
    }
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Asynchronously removes graph nodes, routing each id to its owning server.
// The returned future resolves to 0 if at least one sub-request succeeds,
// -1 if every sub-request fails.
std::future<int32_t> GraphBrpcClient::remove_graph_node(
    uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list) {
  std::vector<std::vector<int64_t>> request_bucket;
  // Dense request slots, one per distinct owning server.
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
  }
  size_t request_call_num = request_bucket.size();
  // Completion callback: count failures; fail only when all requests failed.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    // Params: graph index, then the raw node-id array for this server.
    closure->request(request_idx)->add_params((char*)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char*)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// char* &buffer,int &actual_size
// Asynchronously samples up to `sample_size` neighbors for each id in
// `node_ids`. Two modes:
//   * server_index != -1: all ids go to that single server; results land at
//     res[node_idx] in input order.
//   * server_index == -1: ids are routed to their owning servers; results
//     land at res[query_idx] (original input position).
// When `need_weight` is set, res_weight receives the edge weight parallel to
// each sampled id. The closure captures res/res_weight by reference, so the
// caller must keep them alive until the returned future resolves.
// Response wire format (per server): size_t node_num, int actual_sizes[node_num],
// then per node a packed run of (int64 id [, float weight]) records.
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> node_ids,
    int sample_size,
    // std::vector<std::vector<std::pair<int64_t, float>>> &res,
    std::vector<std::vector<int64_t>>& res,
    std::vector<std::vector<float>>& res_weight,
    bool need_weight,
    int server_index) {
  // ---- Single-server fast path -------------------------------------------
  if (server_index != -1) {
    res.resize(node_ids.size());
    if (need_weight) {
      res_weight.resize(node_ids.size());
    }
    DownpourBrpcClosure* closure = new DownpourBrpcClosure(1, [&](void* done) {
      int ret = 0;
      auto* closure = (DownpourBrpcClosure*)done;
      if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
          0) {
        ret = -1;
      } else {
        // Copy the attachment into a flat buffer, then decode the header
        // (node_num + per-node byte counts) and the packed records.
        auto& res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        size_t bytes_size = io_buffer_itr.bytes_left();
        std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
        char* buffer = buffer_wrapper.get();
        io_buffer_itr.copy_and_forward((void*)(buffer), bytes_size);
        size_t node_num = *(size_t*)buffer;
        int* actual_sizes = (int*)(buffer + sizeof(size_t));
        char* node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
        int offset = 0;
        for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
          int actual_size = actual_sizes[node_idx];
          int start = 0;
          while (start < actual_size) {
            res[node_idx].emplace_back(
                *(int64_t*)(node_buffer + offset + start));
            start += GraphNode::id_size;
            if (need_weight) {
              res_weight[node_idx].emplace_back(
                  *(float*)(node_buffer + offset + start));
              start += GraphNode::weight_size;
            }
          }
          offset += actual_size;
        }
      }
      closure->set_promise_value(ret);
    });
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure->add_promise(promise);
    std::future<int> fut = promise->get_future();
    ;
    closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
    closure->request(0)->set_table_id(table_id);
    closure->request(0)->set_client_id(_client_id);
    // Params: graph index, node-id array, sample size, weight flag.
    closure->request(0)->add_params((char*)&idx_, sizeof(int));
    closure->request(0)->add_params((char*)node_ids.data(),
                                    sizeof(int64_t) * node_ids.size());
    closure->request(0)->add_params((char*)&sample_size, sizeof(int));
    closure->request(0)->add_params((char*)&need_weight, sizeof(bool));
    ;
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(
        closure->cntl(0), closure->request(0), closure->response(0), closure);
    return fut;
  }
  // ---- Scatter path: route each id to its owning server ------------------
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  res.clear();
  res_weight.clear();
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
    // res.push_back(std::vector<std::pair<int64_t, float>>());
    res.push_back({});
    if (need_weight) {
      res_weight.push_back({});
    }
  }
  size_t request_call_num = request2server.size();
  // Bucket node ids (and their original positions) per request slot.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // Completion callback: decode each server's response into the original
  // query positions; fail (-1) only when every request failed.
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void* done) {
        int ret = 0;
        auto* closure = (DownpourBrpcClosure*)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_SAMPLE_NEIGHBORS) != 0) {
            ++fail_num;
          } else {
            auto& res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char* buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void*)(buffer), bytes_size);
            size_t node_num = *(size_t*)buffer;
            int* actual_sizes = (int*)(buffer + sizeof(size_t));
            char* node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
            int offset = 0;
            for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
              // Map this bucket position back to the original query index.
              int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
              int actual_size = actual_sizes[node_idx];
              int start = 0;
              while (start < actual_size) {
                res[query_idx].emplace_back(
                    *(int64_t*)(node_buffer + offset + start));
                start += GraphNode::id_size;
                if (need_weight) {
                  res_weight[query_idx].emplace_back(
                      *(float*)(node_buffer + offset + start));
                  start += GraphNode::weight_size;
                }
              }
              offset += actual_size;
            }
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    // Params: graph index, node-id array, sample size, weight flag.
    closure->request(request_idx)->add_params((char*)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char*)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    closure->request(request_idx)->add_params((char*)&sample_size, sizeof(int));
    closure->request(request_idx)->add_params((char*)&need_weight,
                                              sizeof(bool));
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Asynchronously asks one server to randomly sample up to `sample_size`
// node ids; results are appended to `ids` (captured by reference — the caller
// must keep it alive until the returned future resolves). Resolves to 0 on
// success, -1 on RPC failure.
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int sample_size,
    std::vector<int64_t>& ids) {
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(1, [&](void* done) {
    int ret = 0;
    auto* closure = (DownpourBrpcClosure*)done;
    if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto& res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // FIX: use unique_ptr instead of raw new/delete — the original leaked
      // the buffer if push_back threw.
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char* buffer = buffer_wrapper.get();
      // BUG FIX: the original never copied the response bytes out of the
      // IOBuf before parsing, so uninitialized heap memory was read into
      // `ids` (compare pull_graph_list, which performs this copy).
      io_buffer_itr.copy_and_forward((void*)(buffer), bytes_size);
      // The attachment is a flat array of int64 node ids.
      size_t index = 0;
      while (index < bytes_size) {
        ids.push_back(*(int64_t*)(buffer + index));
        index += GraphNode::id_size;
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  // Params: type id, graph index, sample size.
  closure->request(0)->add_params((char*)&type_id, sizeof(int));
  closure->request(0)->add_params((char*)&idx_, sizeof(int));
  closure->request(0)->add_params((char*)&sample_size, sizeof(int));
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pulls a range of nodes [start, start + size) with stride `step` from one
// server's shard of table `table_id` and decodes them into `res`.
// Returns a future resolving to 0 on success, -1 on RPC/response failure.
// NOTE: `res` is captured by reference in the completion callback; the caller
// must keep it alive until the future resolves.
std::future<int32_t> GraphBrpcClient::pull_graph_list(
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int start,
    int size,
    int step,
    std::vector<FeatureNode> &res) {
  // One-shot closure; deletes itself after the response is processed.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
      ret = -1;
    } else {
      // The server streams serialized FeatureNodes in the attachment; copy
      // the whole attachment into a flat buffer, then decode sequentially.
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      char *buffer = new char[bytes_size];
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      size_t index = 0;
      while (index < bytes_size) {
        FeatureNode node;
        node.recover_from_buffer(buffer + index);
        // get_size(false): node size without feature payload — presumably
        // matches the server's serialization (see pull_graph_list handler,
        // which serializes with need_feature=false).
        index += node.get_size(false);
        res.push_back(node);
      }
      delete[] buffer;
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  // Wire format (must match the server-side pull_graph_list handler):
  // param 0 = type_id, 1 = idx_, 2 = start, 3 = size, 4 = step.
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&start, sizeof(int));
  closure->request(0)->add_params((char *)&size, sizeof(int));
  closure->request(0)->add_params((char *)&step, sizeof(int));
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(closure->cntl(0),
                   closure->request(0),
                   closure->response(0),
                   closure);
  return fut;
}
// Writes string features for a batch of nodes. Node ids are bucketed by the
// server that owns them (get_server_index_by_id) and one PS_GRAPH_SET_NODE_FEAT
// request is issued per involved server.
//
// features[feat_idx][query_idx] is the value of feature_names[feat_idx] for
// node_ids[query_idx]; both outer vectors are indexed in feature-name order.
// Returns a future resolving to 0 if at least one request succeeded, -1 only
// when every request failed (see the fail_num check below).
std::future<int32_t> GraphBrpcClient::set_node_feat(
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &features) {
  // request2server[i] = server index handled by the i-th request;
  // server2request[s] = request slot for server s, or -1 if unused.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Per-request buckets of node ids, original query positions, and the
  // feature values to send (features_idx_buckets[req][feat][local_node]).
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
      request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    if (features_idx_buckets[request_idx].size() == 0) {
      features_idx_buckets[request_idx].resize(feature_names.size());
    }
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      features_idx_buckets[request_idx][feat_idx].push_back(
          features[feat_idx][query_idx]);
    }
  }
  // The buckets are copied into the lambda (by-value captures) because the
  // callback runs after this function has returned.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
              0) {
            ++fail_num;
          }
          // Only an all-requests failure is reported to the caller.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    // Wire format: param 0 = idx_, 1 = packed node ids,
    // 2 = '\t'-joined feature names, 3 = packed feature payload.
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    // set features
    // Payload layout: for each feature, for each node:
    // [size_t length][raw bytes] — decoded in the same order server-side.
    std::string set_feature = "";
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
        size_t feat_len =
            features_idx_buckets[request_idx][feat_idx][node_idx].size();
        set_feature.append((char *)&feat_len, sizeof(size_t));
        set_feature.append(
            features_idx_buckets[request_idx][feat_idx][node_idx].data(),
            feat_len);
      }
    }
    closure->request(request_idx)
        ->add_params(set_feature.c_str(), set_feature.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Client bootstrap: run the base BrpcPsClient initialization, cache how many
// parameter servers exist, and leave the optional in-process shortcut (local
// channel + local graph service) unset until the caller wires it up via
// set_local_channel / set_local_graph_service.
int32_t GraphBrpcClient::Initialize() {
  // set_shard_num(_config.shard_num());
  BrpcPsClient::Initialize();
  local_channel = NULL;
  graph_service = NULL;
  server_size = GetServerNums();
  return 0;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_client.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace
paddle
{
namespace
distributed
{
// A PsService_Stub that can short-circuit RPCs: when the target channel is the
// local server's own channel, the call can be served by the in-process
// GraphBrpcService (via task_pool) instead of going over the network.
// Behavior of that dispatch lives in the out-of-line service() definition,
// which is not visible here.
class GraphPsService_Stub : public PsService_Stub {
 public:
  // `channel`: the remote channel this stub normally targets.
  // `local_channel`: the channel of the local server, used to detect
  //   "calling myself" (may be NULL — then no short-circuit).
  // `service`: in-process service used for the local fast path (may be NULL).
  // `thread_num`: size of the private thread pool used for local dispatch.
  GraphPsService_Stub(::google::protobuf::RpcChannel *channel,
                      ::google::protobuf::RpcChannel *local_channel = NULL,
                      GraphBrpcService *service = NULL,
                      int thread_num = 1)
      : PsService_Stub(channel) {
    this->local_channel = local_channel;
    this->graph_service = service;
    task_pool.reset(new ::ThreadPool(thread_num));
  }
  virtual ~GraphPsService_Stub() {}

  // implements PsService ------------------------------------------
  // Non-owning pointer to the in-process service (NULL when unavailable).
  GraphBrpcService *graph_service;
  // Worker pool for executing locally short-circuited calls.
  std::shared_ptr<::ThreadPool> task_pool;
  // Non-owning; identifies the local server's channel for comparison.
  ::google::protobuf::RpcChannel *local_channel;
  GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub);
  // Overrides the generated stub entry point (definition elsewhere).
  void service(::google::protobuf::RpcController *controller,
               const ::paddle::distributed::PsRequestMessage *request,
               ::paddle::distributed::PsResponseMessage *response,
               ::google::protobuf::Closure *done);
};
// brpc-based client for the distributed graph parameter-server tables.
// All graph queries return a std::future<int32_t> resolving to 0 on success,
// -1 on failure; output parameters are captured by reference by the RPC
// completion callbacks, so callers must keep them alive until the future
// resolves.
class GraphBrpcClient : public BrpcPsClient {
 public:
  GraphBrpcClient() {}
  virtual ~GraphBrpcClient() {}
  // given a batch of nodes, sample graph_neighbors for each of them
  virtual std::future<int32_t> batch_sample_neighbors(
      uint32_t table_id,
      int idx,
      std::vector<int64_t> node_ids,
      int sample_size,
      std::vector<std::vector<int64_t>>& res,
      std::vector<std::vector<float>>& res_weight,
      bool need_weight,
      int server_index = -1);
  // Pull nodes [start, start+size) with stride `step` from one server.
  virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
                                               int type_id,
                                               int idx,
                                               int server_index,
                                               int start,
                                               int size,
                                               int step,
                                               std::vector<FeatureNode>& res);
  // Ask one server for `sample_size` randomly sampled node ids.
  virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
                                                   int type_id,
                                                   int idx,
                                                   int server_index,
                                                   int sample_size,
                                                   std::vector<int64_t>& ids);
  // Fetch named string features for a batch of nodes;
  // res[feat][i] corresponds to feature_names[feat] of node_ids[i].
  virtual std::future<int32_t> get_node_feat(
      const uint32_t& table_id,
      int idx,
      const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      std::vector<std::vector<std::string>>& res);
  // Write named string features for a batch of nodes (inverse of
  // get_node_feat; same indexing convention).
  virtual std::future<int32_t> set_node_feat(
      const uint32_t& table_id,
      int idx,
      const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      const std::vector<std::vector<std::string>>& features);
  // Clear all nodes of the given node/edge set on every server.
  virtual std::future<int32_t> clear_nodes(uint32_t table_id,
                                           int type_id,
                                           int idx);
  // Add nodes (optionally flagged weighted) to the graph table.
  virtual std::future<int32_t> add_graph_node(
      uint32_t table_id,
      int idx,
      std::vector<int64_t>& node_id_list,
      std::vector<bool>& is_weighted_list);
  // Remove the listed nodes from the graph table.
  virtual std::future<int32_t> remove_graph_node(
      uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
  virtual int32_t Initialize();
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  // Maps a node id to the index of the server that owns it.
  int get_server_index_by_id(int64_t id);
  // Record which channel belongs to this process so getServiceStub can
  // short-circuit self-calls through the local graph service.
  void set_local_channel(int index) {
    this->local_channel = GetCmdChannel(index);
  }
  void set_local_graph_service(GraphBrpcService* graph_service) {
    this->graph_service = graph_service;
  }
  // Builds a stub targeting `channel`, carrying the local short-circuit info.
  GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
                                     int thread_num = 1) {
    return GraphPsService_Stub(
        channel, local_channel, graph_service, thread_num);
  }

 private:
  int shard_num;
  // Number of parameter servers, cached by Initialize().
  size_t server_size;
  // Non-owning; NULL until set_local_channel is called.
  ::google::protobuf::RpcChannel* local_channel;
  // Non-owning; NULL until set_local_graph_service is called.
  GraphBrpcService* graph_service;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
distributed
{
// Guard used at the top of RPC handlers: if the dispatcher resolved no table
// for request.table_id(), record error -1 in the response and abort the
// handler. Note it expands to a bare `if` — callers must not follow it with
// a dangling `else`.
#define CHECK_TABLE_EXIST(table, request, response) \
  if (table == NULL) { \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id())); \
    set_response_code(response, -1, err_msg.c_str()); \
    return -1; \
  }
// Creates the configured PsBaseService subclass, initializes it, and attaches
// it to the embedded brpc server. Returns 0 on success, -1 on any failure
// (missing config, unregistered class, init failure, brpc registration
// failure).
int32_t GraphBrpcServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  // Instantiate the service by its registered class name (factory macro).
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service owns the instance from here on.
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  // SERVER_DOESNT_OWN_SERVICE: brpc must not delete it — _service does.
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Returns the peer-to-peer command channel for `server_index` as a borrowed
// pointer; ownership stays with _pserver_channels (populated by
// build_peer2peer_connection).
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  auto &slot = _pserver_channels[server_index];
  return slot.get();
}
// Starts the embedded brpc server on ip:port and registers this server with
// the PS environment.
// NOTE(review): returns 0 on BOTH the failure path and the success path —
// callers cannot distinguish them from the return value; the failure is only
// logged. Confirm whether this is intentional before relying on it.
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;
  // Size the worker pool to max(trainer count, hardware threads).
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
  // Publish our endpoint so peers and clients can discover it.
  _environment->RegistePsServer(ip, port, _rank);
  return 0;
}
// Opens a pooled brpc channel from this server to every registered PS server
// (including, by index, itself), enabling server-to-server fan-out such as
// sample_neighbors_across_multi_servers. Returns 0 on success, -1 when a
// channel cannot be established even after retrying with the numeric-IP form
// of the endpoint.
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
  auto _env = Environment();
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;
  std::vector<PSHost> server_list = _env->GetPsServers();
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) !=
        0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      // Fallback: retry with the endpoint resolved to its numeric form
      // (hostname resolution may have failed).
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) !=
          0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}
// Handler for PS_GRAPH_CLEAR: wipes the node set identified by
// (type_id = param 0, idx_ = param 1) from the graph table.
// Returns 0; input errors are reported through `response`.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Consistency/robustness fix: this was the only handler that dereferenced
  // `table` and the request params without the guards every sibling handler
  // uses (null table => crash; missing params => protobuf range error).
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}
// Handler for PS_GRAPH_ADD_GRAPH_NODE: inserts a batch of nodes.
// Wire format: param 0 = edge-type index, param 1 = packed node ids,
// optional param 2 = packed bool per node ("is weighted").
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "add_graph_node request requires at least 2 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  // NOTE(review): divides by sizeof(int64_t) but decodes as uint64_t —
  // harmless since the sizes match, but the sibling handlers consistently
  // use sizeof(uint64_t).
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
  // Empty when param 2 is absent; GraphTable then decides weighting itself.
  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    size_t weight_list_size = request.params(2).size() / sizeof(bool);
    bool *is_weighted_buffer = (bool *)(request.params(2).c_str());
    is_weighted_list = std::vector<bool>(is_weighted_buffer,
                                         is_weighted_buffer +
                                             weight_list_size);
  }
  // if (request.params_size() == 2) {
  //   size_t weight_list_size = request.params(1).size() / sizeof(bool);
  //   bool *is_weighted_buffer = (bool *)(request.params(1).c_str());
  //   is_weighted_list = std::vector<bool>(is_weighted_buffer,
  //                                        is_weighted_buffer +
  //                                        weight_list_size);
  // }
  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Handler for PS_GRAPH_REMOVE_GRAPH_NODE: decodes an edge-type index
// (param 0) and a packed uint64 id array (param 1) and removes those nodes
// from the graph table. Bad input is reported via `response`; the handler
// itself still returns 0.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }
  int edge_type_idx = *(int *)(request.params(0).c_str());
  uint64_t *id_begin = (uint64_t *)(request.params(1).c_str());
  size_t id_count = request.params(1).size() / sizeof(uint64_t);
  std::vector<uint64_t> ids_to_remove(id_begin, id_begin + id_count);
  ((GraphTable *)table)->remove_graph_node(edge_type_idx, ids_to_remove);
  return 0;
}
// Port the embedded brpc server is actually listening on (meaningful once
// Start() has been called).
int32_t GraphBrpcServer::Port() {
  const auto bound_address = _server.listen_address();
  return bound_address.port;
}
// Populates the cmd_id -> member-function dispatch table used by service()
// and eagerly initializes shard info. Always returns 0.
int32_t GraphBrpcService::Initialize() {
  _is_initialize_shard_info = false;
  // Generic PS admin commands.
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
  _service_handler_map[PS_PRINT_TABLE_STAT] =
      &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
  // Graph-specific commands.
  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;
  // Fan-out sampling entry point (one server coordinates the others).
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
      &GraphBrpcService::sample_neighbors_across_multi_servers;
  InitializeShardInfo();
  return 0;
}
// Propagates shard placement (this rank, total server count) to every table
// exactly once. Uses double-checked locking on _is_initialize_shard_info so
// concurrent callers after the first are cheap no-ops. Always returns 0.
int32_t GraphBrpcService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      // Another thread finished initialization while we waited for the lock.
      return 0;
    }
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // Perf fix: iterate by reference — the original `for (auto itr : ...)`
    // copied each map entry (key plus mapped smart pointer) per iteration.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, server_size);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// brpc entry point: validates the request, looks up the handler registered
// for request->cmd_id() in _service_handler_map, and invokes it. Handler
// errors are surfaced through the response's err_code/err_msg. The
// ClosureGuard guarantees `done` runs on every exit path.
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    // NOTE(review): "tabel_id" typo is in the wire-visible message; left
    // untouched here since clients may match on it.
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }
  // Default to success; handlers overwrite on failure.
  response->set_err_code(0);
  response->set_err_msg("");
  // May be NULL — each handler guards with CHECK_TABLE_EXIST.
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  // Pointer-to-member dispatch into the concrete handler.
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    if (!response->has_err_msg()) {
      response->set_err_msg("server internal error");
    }
  }
}
// Handler for PS_BARRIER: forwards (trainer id, barrier type from param 0)
// to the table's Barrier implementation.
int32_t GraphBrpcService::Barrier(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
  table->Barrier(trainer_id, barrier_type);
  return 0;
}
// Handler for PS_PRINT_TABLE_STAT: serializes the table's (first, second)
// stat pair with BinaryArchive and returns it in response.data.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> ret = table->PrintTableStat();
  paddle::framework::BinaryArchive ar;
  ar << ret.first << ret.second;
  std::string table_info(ar.Buffer(), ar.Length());
  response.set_data(table_info);
  return 0;
}
// Handler for PS_LOAD_ONE_TABLE: loads table data from param 0 (path) with
// param 1 (load_param). Unlike most handlers it returns -1 on failure (also
// used directly by LoadAllTable, which checks the return value).
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  if (table->Load(request.params(0), request.params(1)) != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
// Handler for PS_LOAD_ALL_TABLE: applies LoadOneTable to every registered
// table, stopping at (and propagating) the first failure. The `table`
// argument from the dispatcher is ignored — iteration covers all tables.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  for (auto &itr : table_map) {
    if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
      LOG(ERROR) << "load table[" << itr.first << "] failed";
      return -1;
    }
  }
  return 0;
}
// Handler for PS_STOP_SERVER: stops the server on a detached background
// thread so this RPC can still complete and be answered, then wakes anyone
// waiting on the server's exit condition variable.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
  // Detached: Stop() would deadlock if run inside this RPC's own worker.
  std::thread t_stop([p_server]() {
    p_server->Stop();
    LOG(INFO) << "Server Stoped";
  });
  p_server->export_cv()->notify_all();
  t_stop.detach();
  return 0;
}
// Handler for PS_STOP_PROFILER: stops the platform profiler and dumps the
// collected events to a per-rank file. table/request/response/cntl unused.
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  // Sprintf formats _rank (an integer) through "%s" — presumably fine with
  // Paddle's type-aware Sprintf, but worth confirming against its docs.
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
// Handler for PS_START_PROFILER: enables Paddle's CPU profiler for this
// server process. All four handler arguments are unused.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  const auto target_state = platform::ProfilerState::kCPU;
  platform::EnableProfiler(target_state);
  return 0;
}
// Handler for PS_PULL_GRAPH_LIST: serializes nodes [start, start+size) with
// stride `step` into the response attachment.
// Wire format of params: 0 = type_id, 1 = idx, 2 = start, 3 = size, 4 = step
// (must match the client-side GraphBrpcClient::pull_graph_list).
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(
        response, -1, "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx = *(int *)(request.params(1).c_str());
  int start = *(int *)(request.params(2).c_str());
  int size = *(int *)(request.params(3).c_str());
  int step = *(int *)(request.params(4).c_str());
  std::unique_ptr<char[]> buffer;
  int actual_size;
  // `false` = serialize without feature payload; the client decodes with
  // get_size(false) accordingly. buffer/actual_size are outputs.
  ((GraphTable *)table)
      ->pull_graph_list(
          type_id, idx, start, size, buffer, actual_size, false, step);
  cntl->response_attachment().append(buffer.get(), actual_size);
  return 0;
}
// Handler for PS_GRAPH_SAMPLE_NEIGHBORS: samples up to `sample_size`
// neighbors for each requested node.
// Wire format of params: 0 = idx_, 1 = packed uint64 node ids,
// 2 = sample_size, 3 = need_weight.
// Response attachment layout: [size_t node_num][int actual_sizes[node_num]]
// [per-node neighbor payloads].
int32_t GraphBrpcService::graph_random_sample_neighbors(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fix: message previously claimed "at least 3" while the check above
    // requires 4 parameters.
    set_response_code(
        response,
        -1,
        "graph_random_sample_neighbors request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  int sample_size = *(int *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // One buffer + actual size per queried node, filled by the table.
  std::vector<std::shared_ptr<char>> buffers(node_num);
  std::vector<int> actual_sizes(node_num, 0);
  ((GraphTable *)table)
      ->random_sample_neighbors(
          idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
// Handler for PS_GRAPH_SAMPLE_NODES: returns `size` randomly sampled node ids
// from the (type_id, idx_) node set, packed into the response attachment.
// Wire format of params: 0 = type_id, 1 = idx_, 2 = sample count.
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // Consistency/robustness fix: this handler previously dereferenced `table`
  // and params 0-2 with no guards, unlike every sibling handler.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_random_sample_nodes request requires at least 3 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  size_t size = *(uint64_t *)(request.params(2).c_str());
  // size_t size = *(int64_t *)(request.params(0).c_str());
  std::unique_ptr<char[]> buffer;
  int actual_size;
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
      0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  } else
    // Sampling failed: reply with an empty payload.
    cntl->response_attachment().append(NULL, 0);
  return 0;
}
// Handler for PS_GRAPH_GET_NODE_FEAT: fetches named string features for a
// batch of nodes.
// Wire format of params: 0 = idx_, 1 = packed uint64 node ids,
// 2 = '\t'-joined feature names.
// Response attachment: for each feature, for each node:
// [size_t length][raw feature bytes] (same order as set_node_feat's payload).
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  // feature[feat_idx][node_idx]: output slot for each (feature, node) pair.
  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));
  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      size_t feat_len = feature[feat_idx][node_idx].size();
      cntl->response_attachment().append(&feat_len, sizeof(size_t));
      cntl->response_attachment().append(feature[feat_idx][node_idx].data(),
                                         feat_len);
    }
  }
  return 0;
}
// Handler for PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER: this server acts as a
// coordinator — it partitions the requested node ids by owning server, fans
// out PS_GRAPH_SAMPLE_NEIGHBORS RPCs to the remote owners over the
// peer-to-peer channels, samples its own shard locally, and merges everything
// back into this RPC's response attachment in the original query order.
//
// Response layout (same as graph_random_sample_neighbors):
// [size_t node_num][int actual_sizes[node_num]][per-node payloads].
// Failed remote servers contribute size 0 and no payload for their nodes.
//
// NOTE: this handler blocks (fut.get()) until the fan-out completes, so the
// whole merge runs synchronously inside this RPC.
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(response,
                      -1,
                      "sample_neighbors_across_multi_servers request requires "
                      "at least 4 arguments");
    return 0;
  }
  // Wire format: 0 = idx_, 1 = packed node ids, 2 = sample_size,
  // 3 = need_weight.
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  int sample_size = *(int *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // request2server[i] = server handled by request slot i;
  // server2request[s] = slot for server s, or -1 when s owns none of the ids.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  std::vector<uint64_t> local_id;
  std::vector<int> local_query_idx;
  size_t rank = GetRank();
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  // If this server owns some of the ids, move its slot to the END of
  // request2server so slots [0, remote_call_num) are all remote and the last
  // slot (if any) is the local one.
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
  std::vector<std::shared_ptr<char>> local_buffers;
  std::vector<int> local_actual_sizes;
  // seq[i] = request slot that handles the i-th queried node; used to merge
  // results back in query order.
  std::vector<size_t> seq;
  std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
    // Last slot is local: it is served below without an RPC.
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  // local_fut gates the merge callback until the local sampling (done near
  // the bottom of this function) has filled local_buffers/local_actual_sizes.
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  // Merge callback: runs once all remote responses arrived (or immediately,
  // via the manual func(closure) call, when remote_call_num == 0).
  // Captures locals by reference — safe only because fut.get() below keeps
  // this stack frame alive until the callback finishes.
  std::function<void(void *)> func = [&,
                                      node_id_buckets,
                                      query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num;
         ++request_idx) {
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        // Skip the leading node-count header of each remote response.
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    // Pass 1: gather per-node payload sizes in original query order.
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));
    // Pass 2: stream the payloads themselves, same order.
    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        // NOTE(review): variable-length array — a GCC/Clang extension, not
        // standard C++; also relies on per-node payloads being small.
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(remote_call_num, func);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fan out one PS_GRAPH_SAMPLE_NEIGHBORS RPC per remote owner over the
  // server-to-server channels built by build_peer2peer_connection.
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(uint64_t) * node_num);
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
    PsService_Stub rpc_stub(
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
    // GraphPsService_Stub rpc_stub =
    //     getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  // Serve this server's own bucket (always the last one, per the swap above).
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
        ->random_sample_neighbors(idx_,
                                  node_id_buckets.back().data(),
                                  sample_size,
                                  local_buffers,
                                  local_actual_sizes,
                                  need_weight);
  }
  // Unblock the merge callback now that local results are ready.
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  // Block until the merge finished — required: the callback references this
  // frame's locals.
  fut.get();
  return 0;
}
// Writes feature values for a batch of graph nodes.
//
// Request layout:
//   params(0): int        — graph type index (idx)
//   params(1): uint64_t[] — packed node ids
//   params(2): string     — '\t'-separated feature names
//   params(3): bytes      — for each feature (outer) and node (inner), a
//                           size_t length prefix followed by that many raw
//                           feature bytes.
//
// Always returns 0; protocol errors are reported through `response`.
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fix: the guard requires 4 params, but the old message claimed 3.
    set_response_code(
        response,
        -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());

  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  std::vector<uint64_t> node_ids(node_data, node_data + node_num);

  // std::vector<std::string> feature_names =
  //     paddle::string::split_string<std::string>(request.params(1), "\t");
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  // features[feat_idx][node_idx] holds the raw bytes for one feature of one
  // node; sized up-front so the decode loop can assign by index.
  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));

  // const char *buffer = request.params(2).c_str();
  const char *buffer = request.params(3).c_str();

  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      // Each entry: size_t length prefix, then `feat_len` raw bytes.
      size_t feat_len = *(size_t *)(buffer);
      buffer += sizeof(size_t);
      auto feat = std::string(buffer, feat_len);
      features[feat_idx][node_idx] = feat;
      buffer += feat_len;
    }
  }

  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);

  return 0;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_server.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/server.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace
paddle
{
namespace
distributed
{
// brpc-based parameter server specialized for graph tables: hosts the
// GraphBrpcService and maintains peer-to-peer command channels to the
// other graph servers.
class GraphBrpcServer : public PSServer {
 public:
  GraphBrpcServer() {}
  virtual ~GraphBrpcServer() {}

  // Returns the raw service object owned by this server (non-owning pointer).
  PsBaseService *get_service() { return _service.get(); }

  // Binds and starts the brpc server on ip:port; returns the host signature.
  virtual uint64_t Start(const std::string &ip, uint32_t port);

  // Establishes command channels to every other graph server so this rank
  // can forward cross-shard requests.
  virtual int32_t build_peer2peer_connection(int rank);

  // Returns the peer-to-peer command channel to `server_index`.
  virtual brpc::Channel *GetCmdChannel(size_t server_index);

  // Idempotent shutdown: stops and joins the brpc server exactly once.
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_) return 0;
    stoped_ = true;
    // cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }

  // Port the brpc server is listening on.
  int32_t Port();

  // Exposes the internal condition variable (e.g. for external waiters).
  std::condition_variable *export_cv() { return &cv_; }

 private:
  virtual int32_t Initialize();

  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool stoped_ = false;  // guarded by mutex_; set once in Stop()
  int rank;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer graph server, indexed by server rank.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class
GraphBrpcService
;
typedef
int32_t
(
GraphBrpcService
::*
serviceFunc
)(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
// brpc service implementing the graph parameter-server RPC surface.
// `service()` dispatches each incoming PsRequestMessage to one of the
// handler methods below via _service_handler_map, keyed by command id.
class GraphBrpcService : public PsBaseService {
 public:
  virtual int32_t Initialize() override;

  // Entry point for every RPC: looks up the handler for the request's
  // command id and invokes it against the target table.
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;

 protected:
  // Command id -> member-function handler (graph-specific commands).
  std::unordered_map<int32_t, serviceFunc> _service_handler_map;

  int32_t InitializeShardInfo();

  // --- graph query/sampling handlers ---
  int32_t pull_graph_list(Table *table,
                          const PsRequestMessage &request,
                          PsResponseMessage &response,
                          brpc::Controller *cntl);
  int32_t graph_random_sample_neighbors(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl);
  int32_t graph_random_sample_nodes(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl);
  int32_t graph_get_node_feat(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  int32_t graph_set_node_feat(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);

  // --- graph mutation handlers ---
  int32_t clear_nodes(Table *table,
                      const PsRequestMessage &request,
                      PsResponseMessage &response,
                      brpc::Controller *cntl);
  int32_t add_graph_node(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  int32_t remove_graph_node(Table *table,
                            const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);

  // --- generic PS control handlers ---
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
                  PsResponseMessage &response,
                  brpc::Controller *cntl);
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);

  // Fans a neighbor-sampling request out to the servers owning each shard
  // and merges the partial results into one response.
  int32_t sample_neighbors_across_multi_servers(Table *table,
                                                const PsRequestMessage &request,
                                                PsResponseMessage &response,
                                                brpc::Controller *cntl);
  int32_t use_neighbors_sample_cache(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl);
  int32_t load_graph_split_config(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl);

 private:
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;
  // Command id -> generic (non-graph) handler, inherited convention from
  // the base brpc PS service.
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
  const int sample_nodes_ranges = 23;
  size_t server_size;
  std::shared_ptr<::ThreadPool> task_pool;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/heter_client.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
distributed
{
// Maximum number of members in a heter group.
DEFINE_int32(heter_world_size, 100, "group size");
// group max size
// Timeout (seconds) for switch send/recv operations.
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");

// Out-of-class definitions of HeterClient's static singleton state
// (declared in heter_client.h).
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
std::mutex HeterClient::mtx_;
std::shared_ptr<HeterClient> HeterClient::switch_s_instance_ = nullptr;
// Reads the micro-batch id from the "microbatch_id" variable in `scope`.
// The variable must be a LoDTensor; its first element (a float) is cast to
// int and returned. For non-CPU tensors the data is first copied to host
// over the GPU stream; without PADDLE_WITH_CUDA a non-CPU tensor yields -1.
int GetMicroId(const platform::DeviceContext &ctx,
               const framework::Scope *scope) {
  framework::Variable *var = scope->FindVar("microbatch_id");
  PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(),
                    true,
                    platform::errors::InvalidArgument(
                        "the type of micro id shoulde be LoDTensor."));
  auto micro_id = -1;  // sentinel if no readable value is found
  auto *tensor = var->GetMutable<framework::LoDTensor>();
  if (platform::is_cpu_place(tensor->place())) {
    // Host tensor: read element 0 directly.
    auto data = reinterpret_cast<const float *>(tensor->data());
    micro_id = static_cast<int>(data[0]);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device tensor: stage the whole tensor into a host buffer via the
    // context's CUDA stream, then read element 0.
    std::vector<char> temp;
    temp.resize(tensor->numel() * framework::DataTypeSize(tensor->dtype()));
    char *temp_ptr = temp.data();
    auto stream = reinterpret_cast<const phi::GPUContext &>(ctx).stream();
    memory::Copy(platform::CPUPlace(),
                 temp_ptr,
                 tensor->place(),
                 tensor->data(),
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    float *temp_ptr_float = reinterpret_cast<float *>(temp_ptr);
    micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
  }
  return micro_id;
}
// Requests every heter worker to stop and blocks until all acknowledge.
void HeterClient::Stop() { StopHeterWorker().wait(); }
// Broadcasts PS_STOP_SERVER to every xpu channel; -1 means "no table".
std::future<int32_t> HeterClient::StopHeterWorker() {
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_STOP_SERVER, no_params);
}
// Broadcasts PS_START_PROFILER to every xpu channel; -1 means "no table".
std::future<int32_t> HeterClient::StartProfiler() {
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_START_PROFILER, no_params);
}
// Broadcasts PS_STOP_PROFILER to every xpu channel; -1 means "no table".
std::future<int32_t> HeterClient::StopProfiler() {
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_STOP_PROFILER, no_params);
}
void
HeterClient
::
CreateClient2XpuConnection
()
{
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
options
.
connection_type
=
"single"
;
options
.
timeout_ms
=
FLAGS_pserver_timeout_ms
;
xpu_channels_
.
resize
(
xpu_list_
.
size
());
for
(
size_t
i
=
0
;
i
<
xpu_list_
.
size
();
++
i
)
{
xpu_channels_
[
i
].
reset
(
new
brpc
::
Channel
());
if
(
xpu_channels_
[
i
]
->
Init
(
xpu_list_
[
i
].
c_str
(),
""
,
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"HeterClient channel init fail. Try Again"
;
auto
ip_port
=
paddle
::
string
::
Split
(
xpu_list_
[
i
],
':'
);
std
::
string
ip
=
ip_port
[
0
];
int
port
=
std
::
stoi
(
ip_port
[
1
]);
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
ip
,
port
);
if
(
xpu_channels_
[
i
]
->
Init
(
int_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"BrpcPsServer start failed, ip_port= "
<<
int_ip_port
;
}
}
}
previous_xpu_channels_
.
resize
(
previous_xpu_list_
.
size
());
for
(
size_t
i
=
0
;
i
<
previous_xpu_list_
.
size
();
++
i
)
{
previous_xpu_channels_
[
i
].
reset
(
new
brpc
::
Channel
());
if
(
previous_xpu_channels_
[
i
]
->
Init
(
previous_xpu_list_
[
i
].
c_str
(),
""
,
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"HeterClient channel init fail. Try Again"
;
auto
ip_port
=
paddle
::
string
::
Split
(
previous_xpu_list_
[
i
],
':'
);
std
::
string
ip
=
ip_port
[
0
];
int
port
=
std
::
stoi
(
ip_port
[
1
]);
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
ip
,
port
);
if
(
previous_xpu_channels_
[
i
]
->
Init
(
int_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"BrpcPsServer start failed, ip_port= "
<<
int_ip_port
;
}
}
}
}
// Serializes `send_var_name` variables from `scope` and issues an
// asynchronous SendAndRecvVariable RPC. The destination channel is picked
// by minibatch id: "forward" -> xpu_channels_, "backward" ->
// previous_xpu_channels_; "send_to_switch" is currently disabled and
// returns without issuing any RPC.
void HeterClient::SendAndRecvAsync(
    const platform::DeviceContext &ctx,
    const framework::Scope &scope,
    const std::string &message_name,
    const std::vector<std::string> &send_var_name,
    const std::vector<std::string> &recv_var_name,
    const std::string &mode) {
  platform::RecordEvent record_event("HeterClient->SendAndRecvAsync",
                                     platform::TracerEventType::Communication,
                                     1);
  const platform::DeviceContext *p_ctx = &ctx;
  const framework::Scope *p_scope = &scope;
  const std::vector<std::string> send_var_name_val = send_var_name;
  const std::vector<std::string> recv_var_name_val = recv_var_name;
  VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name;
  brpc::Channel *channel = nullptr;
  distributed::MultiVarMsg request;
  OnHeterRpcDone *closure = new OnHeterRpcDone([](void *done) {
    auto *closure = reinterpret_cast<OnHeterRpcDone *>(done);
    PADDLE_ENFORCE_NE(
        closure->cntl.Failed(),
        true,
        platform::errors::Unimplemented(
            "HeterClient::SendAndRecv meets brpc error, error message is %s",
            closure->cntl.ErrorText()));
    VLOG(4) << "call heter_worker success";
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto &request_io_buffer = closure->cntl.request_attachment();
  distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                              send_var_name_val,
                                              recv_var_name_val,
                                              *p_ctx,
                                              p_scope,
                                              &request,
                                              &request_io_buffer);

  int micro_id = GetMicroId(ctx, p_scope);
  // global
  auto minibatch_id = micro_id / 10;
  VLOG(4) << "micro_id: " << micro_id;
  // select channel according to micro id
  if (mode == "forward") {
    int num = minibatch_id % xpu_channels_.size();
    channel = xpu_channels_[num].get();
  } else if (mode == "backward") {
    int num = minibatch_id % previous_xpu_channels_.size();
    channel = previous_xpu_channels_[num].get();
  } else if (mode == "send_to_switch") {
    VLOG(4) << "calling switch service";
    // auto promise = std::make_shared<std::promise<int32_t>>();
    // closure->add_promise(promise);
    // std::future<int> fut = promise->get_future();
    // int idx = 1; // for test
    // LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size();
    // channel = xpu_channels_[idx].get(); // 为了适配 send_and_recv op
    // ::paddle::distributed::PsService_Stub stub(channel);
    // stub.SendToSwitch(&closure->cntl, &request, &closure->response,
    // closure); fut.wait();
    VLOG(4) << "calling switch service done";
    // Fix: the closure is never handed to a stub on this path, so brpc will
    // never run/own it — delete it here to avoid leaking it per call.
    delete closure;
    return;
  }
  // NOTE(review): any other `mode` leaves `channel` null and the stub call
  // below would dereference it — presumably callers only pass the three
  // modes above; confirm before hardening.
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendAndRecvVariable(
      &closure->cntl, &request, &closure->response, closure);
}
std
::
future
<
int32_t
>
HeterClient
::
SendCmd
(
uint32_t
table_id
,
int
cmd_id
,
const
std
::
vector
<
std
::
string
>&
params
)
{
size_t
request_call_num
=
xpu_channels_
.
size
();
paddle
::
distributed
::
DownpourBrpcClosure
*
closure
=
new
paddle
::
distributed
::
DownpourBrpcClosure
(
request_call_num
,
[
request_call_num
,
cmd_id
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
cmd_id
)
!=
0
)
{
ret
=
-
1
;
break
;
}
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
closure
->
request
(
i
)
->
set_cmd_id
(
cmd_id
);
closure
->
request
(
i
)
->
set_table_id
(
table_id
);
closure
->
request
(
i
)
->
set_client_id
(
trainer_id_
);
for
(
const
auto
&
param
:
params
)
{
closure
->
request
(
i
)
->
add_params
(
param
);
}
::
paddle
::
distributed
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
closure
->
cntl
(
i
)
->
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
// cmd msg don't limit timeout for save/load
rpc_stub
.
service
(
closure
->
cntl
(
i
),
closure
->
request
(
i
),
closure
->
response
(
i
),
closure
);
}
return
fut
;
}
// Serializes the named variables from `scope` into a MultiVarMsg and sends
// them to the (first) switch endpoint, blocking until the RPC completes.
// Always returns 0; brpc failures are raised via PADDLE_ENFORCE_NE in the
// completion callback.
int HeterClient::Send(const platform::DeviceContext &ctx,
                      const framework::Scope &scope,
                      const std::string &message_name,
                      const std::vector<std::string> &send_var_names) {
  const framework::Scope *p_scope = &scope;  // note: intentionally const
  OnHeterRpcDone *closure = new OnHeterRpcDone([](void *done) {
    auto *closure = reinterpret_cast<OnHeterRpcDone *>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      PADDLE_ENFORCE_NE(
          closure->cntl.Failed(),
          true,
          platform::errors::Unimplemented(
              "HeterClient::SendToSwitch meets brpc error, error message is %s",
              closure->cntl.ErrorText()));
    }
  });

  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto &request_io_buffer = closure->cntl.request_attachment();

  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);

  // 2. set req send_var_names(<string>)
  for (auto &send_var_name : send_var_names) {
    request.add_send_var_names(send_var_name);
  }

  // 3. set req var_messages(<VarMessage>)
  for (auto &send_var_name : send_var_names) {
    auto *send_var_msg = request.add_var_messages();
    send_var_msg->set_varname(send_var_name);
    framework::Variable *var = p_scope->FindVar(send_var_name);
    butil::IOBuf temp_iobuf;
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    request_io_buffer.append(temp_iobuf);
  }

  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel *channel = send_switch_channels_[0].get();
  // brpc::Channel* channel = xpu_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);

  VLOG(4) << "waiting SendToSwitch response result......";
  fut.wait();
  VLOG(4) << "Send done";
  // Fix: release the closure after the RPC completes, matching the
  // Send(group_id, ...) overload; previously it was leaked on every call.
  delete closure;
  return 0;
}
// Sends a raw byte buffer (plus variable names and per-variable lengths) to
// the (first) switch endpoint under the fixed message name "send and save",
// blocking until the RPC completes. Always returns 0; brpc failures are
// only logged from the completion callback.
int HeterClient::Send(int group_id,
                      const std::vector<std::string> &var_names,
                      const std::vector<int64_t> &vars_size,
                      void *data_ptr,
                      int64_t data_size) {
  OnHeterRpcDone *closure = new OnHeterRpcDone([](void *done) {
    auto *closure = reinterpret_cast<OnHeterRpcDone *>(done);
    int ret = 0;
    // Promise is fulfilled before the failure check, so Send() always
    // unblocks even when the RPC failed.
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      LOG(ERROR) << "Send meets brpc error, err msg is %s"
                 << closure->cntl.ErrorText();
    }
  });
  distributed::MultiVarMsg request;
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  std::string message_name = "send and save";
  request.set_message_name(message_name);
  request.set_group_id(group_id);
  for (auto &send_var_name : var_names) {
    request.add_send_var_names(send_var_name);
  }
  for (auto var_len : vars_size) {
    request.add_vars_len(var_len);
  }
  // The payload itself travels in the request attachment, not the proto.
  auto &request_buffer = closure->cntl.request_attachment();
  request_buffer.append(reinterpret_cast<void *>(data_ptr), data_size);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (send_switch_channels_.empty()) {
    // Fallback: reuse the first xpu channel as the switch channel.
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel *channel = send_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  fut.wait();
  delete closure;
  return 0;
}
// Requests the named variables from the switch, blocks for the reply, and
// deserializes the response attachment into `recv_scope` on the CPU.
// Always returns 0; brpc failures are only logged from the callback.
int HeterClient::Recv(const platform::DeviceContext &ctx,
                      framework::Scope &recv_scope,  // NOLINT
                      const std::string &message_name,
                      const std::vector<std::string> &recv_var_names) {
  OnHeterRpcDone *closure = new OnHeterRpcDone([](void *done) {
    auto *closure = reinterpret_cast<OnHeterRpcDone *>(done);
    VLOG(4) << "Recv service call done";
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      VLOG(4) << "HeterClient::RecvFromSwitch meets "
                 "brpc error, error message is %s"
              << closure->cntl.ErrorText();
    }
  });

  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);

  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);

  // 2. set req recv_var_names(<string>)
  for (auto &recv_var_name : recv_var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    recv_switch_channels_.push_back(xpu_channels_[1]);
  }
  brpc::Channel *channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto &cpu_dev_ctx = *pool.Get(cpu_place);
  auto &res_io_buffer = closure->cntl.response_attachment();
  VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf";
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope);
  // Fix: release the closure after its response/attachment have been fully
  // consumed, matching the Recv(group_id, ...) overload; it was leaked here.
  delete closure;
  VLOG(4) << "Recv done";
  return 0;
}
// Queries the switch for the named variables ("query and recv") and copies
// the raw response attachment into `data_ptr` (up to `data_size` bytes),
// blocking until the RPC completes. Always returns 0; brpc failures are
// only logged from the completion callback.
int HeterClient::Recv(int group_id,
                      const std::vector<std::string> &var_names,
                      void *data_ptr,
                      int64_t data_size) {
  OnHeterRpcDone *closure = new OnHeterRpcDone([](void *done) {
    auto *closure = reinterpret_cast<OnHeterRpcDone *>(done);
    int ret = 0;
    // Promise is fulfilled before the failure check so Recv() always
    // unblocks even when the RPC failed.
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      LOG(ERROR) << "Recv meets brpc error, err msg is %s"
                 << closure->cntl.ErrorText();
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_s * 0 +
                               FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  std::string message_name = "query and recv";
  request.set_message_name(message_name);
  request.set_group_id(group_id);

  for (auto &recv_var_name : var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    // NOTE(review): this overload falls back to xpu_channels_[0] while the
    // scope-based Recv overload uses xpu_channels_[1] — looks inconsistent;
    // confirm which endpoint the switch actually listens on.
    recv_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel *channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  auto &res_io_buffer = closure->cntl.response_attachment();
  butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
  io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(data_ptr),
                                 data_size);
  delete closure;
  VLOG(4) << "Recv done";
  return 0;
}
}
// namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/service/heter_client.h
0 → 100644
View file @
de2e6515
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/string/split.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
distributed
{
DECLARE_int32(pserver_timeout_ms);

// Short aliases for the protobuf message types used throughout the
// heter client/server code.
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;

// Completion callback for heter RPCs; receives the OnHeterRpcDone closure
// itself as an opaque pointer.
// (Modernized from `typedef` to a `using` alias; same name, same type.)
using HeterRpcCallbackFunc = std::function<void(void *)>;
// protobuf Closure that carries one heter RPC's request/response state and
// invokes a user callback on completion. Callers attach promises which the
// callback fulfills via set_promise_value() to unblock waiting threads.
class OnHeterRpcDone : public google::protobuf::Closure {
 public:
  explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
  virtual ~OnHeterRpcDone() {}
  // Invoked by brpc when the RPC completes; forwards to the user callback.
  // Note: does not delete itself — the owner must free the closure.
  void Run() { handler_(this); }

  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {  // NOLINT
    _promises.push_back(promise);
  }

  // Fulfills every attached promise with `value`.
  void set_promise_value(int value) {
    for (auto &promise : _promises) {
      promise->set_value(value);
    }
  }
  int CheckResponse() { return 0; }
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
  HeterRpcCallbackFunc handler_;

  MultiVariableMessage request;
  MultiVariableMessage response;

  PsResponseMessage ps_response;

  brpc::Controller cntl;
  // PsRequestMessage *request(size_t i) { return &_requests[i]; }
  // PsResponseMessage *response(size_t i) { return &_responses[i]; }
  // std::vector<PsRequestMessage> _requests;
  // std::vector<PsResponseMessage> _responses;
  // std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Client side of the heter parameter-server / switch topology. Manages brpc
// channels to heter workers (xpu), the previous pipeline stage, and peer
// switch/worker endpoints; accessed through the GetInstance /
// GetSwitchInstance singletons.
class HeterClient {
 public:
  virtual ~HeterClient() {}

  // Builds channels to peer endpoints. `peer_role` selects the target list
  // (switch vs worker); `need_encrypt` enables SSL on the channel options.
  // Each endpoint is retried once with its int-typed form on init failure.
  void InitClientChannels(bool need_encrypt,
                          const std::vector<std::string> &node_list,
                          int32_t peer_role) {
    brpc::ChannelOptions options;
    options.protocol = "baidu_std";
    options.connection_type = "single";
    options.timeout_ms = FLAGS_pserver_timeout_ms;
    std::vector<std::shared_ptr<brpc::Channel>> *client_channels = nullptr;
    if (peer_role == PEER_ROLE_IS_SWITCH) {
#ifdef PADDLE_WITH_ARM_BRPC
      if (need_encrypt) {
        options.mutable_ssl_options();
      }
      options.connection_type = "";
      VLOG(4) << "ssl enabled in arm";
#else
      if (need_encrypt) {
        options.mutable_ssl_options();
      }
#endif
      client_channels = &peer_switch_channels_;
    } else if (peer_role == PEER_ROLE_IS_WORKER) {
      client_channels = &peer_worker_channels_;
    } else {
      LOG(ERROR) << "init switch client failed, peer_role not valid";
      // Fix: bail out instead of dereferencing the null channel list below.
      return;
    }
    (*client_channels).resize(node_list.size());
    for (size_t i = 0; i < node_list.size(); ++i) {
      (*client_channels)[i].reset(new brpc::Channel());
      if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
          0) {
        VLOG(0) << "client channel init failed! try again";
        auto ip_port = paddle::string::Split(node_list[i], ':');
        std::string ip = ip_port[0];
        int port = std::stoi(ip_port[1]);
        std::string int_ip_port = GetIntTypeEndpoint(ip, port);
        if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "client channel init failed! peer ip_port = "
                     << int_ip_port;
        }
      }
    }
    VLOG(4) << "InitClientChannels success";
  }

  void CreateClient2XpuConnection();

  void SendAndRecvAsync(const platform::DeviceContext &ctx,
                        const framework::Scope &scope,
                        const std::string &message_name,
                        const std::vector<std::string> &send_var_name,
                        const std::vector<std::string> &recv_var_name,
                        const std::string &mode = "forward");

  int Send(int group_id,
           const std::vector<std::string> &var_names,
           const std::vector<int64_t> &vars_len,
           void *data_ptr,
           int64_t data_size);

  int Send(const platform::DeviceContext &ctx,
           const framework::Scope &scope,
           const std::string &message_name,
           const std::vector<std::string> &send_var_names);

  int Recv(int group_id,
           const std::vector<std::string> &var_names,
           void *data_ptr,
           int64_t data_size);

  int Recv(const platform::DeviceContext &ctx,
           framework::Scope &recv_scope,  // NOLINT
           const std::string &message_name,
           const std::vector<std::string> &recv_var_names);

  // HeterClient singleton; created (and connected) on first use.
  static std::shared_ptr<HeterClient> GetInstance(
      const std::vector<std::string> &endpoints,
      const std::vector<std::string> &previous_endpoints,
      const int &trainer_id) {
    if (nullptr == s_instance_) {
      s_instance_.reset(new HeterClient());
      s_instance_->SetXpuList(endpoints);
      s_instance_->SetPreviousXpuList(previous_endpoints);
      s_instance_->SetTrainerID(trainer_id);
      s_instance_->CreateClient2XpuConnection();
    }
    return s_instance_;
  }

  // switch client singleton; guarded by mtx_.
  static std::shared_ptr<HeterClient> GetSwitchInstance(
      const std::vector<std::string> &peer_endpoints, int32_t peer_role) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (peer_endpoints.empty()) {
      VLOG(4) << "init switch client failed, null peer_endpoints";
    } else {
      // Fix: only read peer_endpoints[0] when the list is non-empty
      // (previously indexed unconditionally — UB on an empty vector).
      VLOG(4) << "peer role is: " << peer_role
              << ", addr is: " << peer_endpoints[0];
    }
    if (switch_s_instance_ == nullptr) {
      switch_s_instance_.reset(new HeterClient());
      switch_s_instance_->SetPeerSwitchList(peer_endpoints);
      switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role);
    }
    return switch_s_instance_;
  }

  void SetPeerSwitchList(const std::vector<std::string> &peer_endpoints) {
    peer_switch_list_ = peer_endpoints;
  }

  void SetPeerWorkerList(const std::vector<std::string> &worker_endpoints) {
    peer_worker_list_ = worker_endpoints;
  }

  void Stop();

  std::future<int32_t> SendCmd(uint32_t table_id,
                               int cmd_id,
                               const std::vector<std::string> &params);

  std::future<int32_t> StartProfiler();

  std::future<int32_t> StopProfiler();

  std::future<int32_t> StopHeterWorker();

  std::vector<std::string> &GetXpuList() { return xpu_list_; }

  void SetXpuList(const std::vector<std::string> &xpu_list) {
    xpu_list_ = xpu_list;
  }

  void SetPreviousXpuList(const std::vector<std::string> &xpu_list) {
    previous_xpu_list_ = xpu_list;
  }

  void SetTrainerID(const int &trainer_id) { trainer_id_ = trainer_id; }

 public:
  std::vector<std::string> send_switch_list_;
  std::vector<std::string> recv_switch_list_;

  std::vector<std::string> peer_switch_list_;
  std::vector<std::string> peer_worker_list_;
  std::vector<std::shared_ptr<brpc::Channel>> send_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> recv_switch_channels_;

  std::vector<std::shared_ptr<brpc::Channel>> peer_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_worker_channels_;

 private:
  HeterClient() {}
  HeterClient &operator=(const HeterClient &);
  HeterClient(const HeterClient &);

  static std::shared_ptr<HeterClient> s_instance_;
  static std::mutex mtx_;
  static std::shared_ptr<HeterClient> switch_s_instance_;
  // Channels to the current and previous pipeline stages.
  std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;

  // DISABLE_COPY_AND_ASSIGN(HeterClient);
  std::vector<std::string> xpu_list_;
  std::vector<std::string> previous_xpu_list_;

  int trainer_id_;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/service/heter_server.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/string/split.h"
namespace
paddle
{
namespace
distributed
{
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
std
::
shared_ptr
<
HeterServer
>
HeterServer
::
s_instance_
=
nullptr
;
std
::
mutex
HeterServer
::
mtx_
;
// Registers a handler for the given message name by forwarding to the
// underlying HeterService, which dispatches SendAndRecvVariable requests
// by message name.
void HeterServer::RegisterServiceHandler(std::string message_name,
                                         HeterServiceHandler func) {
  service_.RegisterServiceHandler(message_name, func);
}
// Starts the main brpc server for the heter service, publishes readiness to
// WaitServerReady(), and then BLOCKS the calling thread until Stop() sets
// stoped_ — this function is intended to run in a dedicated thread.
void HeterServer::StartHeterService(bool neeed_encrypt) {
  server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
    // SSL material is read from hard-coded absolute paths.
    options.mutable_ssl_options()->default_cert.certificate = "/cert.pem";
    options.mutable_ssl_options()->default_cert.private_key = "/key.pem";
  }
  if (server_.Start(endpoint_.c_str(), &options) != 0) {
    VLOG(0) << "HeterServer start fail. Try again.";
    // Build a numeric-IP form of "host:port" for diagnostics.
    auto ip_port = paddle::string::Split(endpoint_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // NOTE(review): the retry still passes endpoint_, not int_ip_port;
    // int_ip_port is only used in the error log. Confirm whether the retry
    // was meant to bind the numeric endpoint instead.
    if (server_.Start(endpoint_.c_str(), &options) != 0) {
      LOG(ERROR) << "HeterServer start failed, ip_port= " << int_ip_port;
    }
  } else {
    VLOG(0) << "heter server start success! listen on " << endpoint_;
  }
  {
    // Publish readiness under mutex_ready_ so WaitServerReady() can wake.
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Park here until Stop() flips stoped_ and notifies cv_.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
// Starts the secondary ("inter") brpc server used for switch-to-switch
// traffic on endpoint_inter_, publishes readiness, and then BLOCKS the
// calling thread until Stop() sets stoped_. Mirrors StartHeterService().
void HeterServer::StartHeterInterService(bool neeed_encrypt) {
  server_inter_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
    // SSL material is read from hard-coded absolute paths.
    options.mutable_ssl_options()->default_cert.certificate = "/cert.pem";
    options.mutable_ssl_options()->default_cert.private_key = "/key.pem";
  }
  if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
    VLOG(4) << "switch inter server start fail. Try again.";
    // Build a numeric-IP form of "host:port" for diagnostics.
    auto ip_port = paddle::string::Split(endpoint_inter_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // NOTE(review): as in StartHeterService, the retry still passes
    // endpoint_inter_; int_ip_port is only logged. Confirm intent.
    if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
      LOG(ERROR) << "switch inter server start failed, ip_port= "
                 << int_ip_port;
    }
  } else {
    VLOG(4) << "switch inter server server start success! listen on "
            << endpoint_inter_;
  }
  {
    // Publish readiness under mutex_ready_ so WaitServerReady() can wake.
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Park here until Stop() flips stoped_ and notifies cv_.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
// Forwards the fan-in count (number of upstream workers that must request a
// stop before the service exits) to the underlying HeterService.
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
// Blocks the caller until StartHeterService()/StartHeterInterService() has
// set ready_ to 1 (signalled via condition_ready_).
void HeterServer::WaitServerReady() {
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
  // Fix: use an explicit [this] capture instead of [=]; in a member function
  // [=] captures `this` implicitly, which is deprecated in C++20. Behavior
  // is unchanged.
  condition_ready_.wait(lock, [this] { return this->ready_ == 1; });
}
int
SendAndRecvVariableHandler
::
SaveInSwitchWithShard
(
const
MultiVarMsg
*
request
,
PsResponseMessage
*
response
,
brpc
::
Controller
*
cntl
)
{
VLOG
(
4
)
<<
"entering SaveInSwitchWithShard"
;
int32_t
group_id
=
request
->
group_id
();
if
(
group_id
>=
FLAGS_heter_world_size
)
{
LOG
(
ERROR
)
<<
"group id exceed maxmium"
;
}
auto
&
local_shard
=
_local_shards
[
group_id
];
auto
&
request_io_buffer
=
cntl
->
request_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
request_io_buffer
);
for
(
int
idx
=
0
;
idx
<
request
->
send_var_names_size
();
idx
++
)
{
const
auto
&
var_name
=
request
->
send_var_names
(
idx
);
const
auto
&
var_size
=
request
->
vars_len
(
idx
);
WaitForVarsConsumed
(
group_id
,
var_name
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
scope_mutex_
);
auto
&
value
=
local_shard
[
var_name
];
value
.
resize
(
var_size
);
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
value
.
data
()),
var_size
);
vars_ready_flag
[
group_id
][
var_name
]
=
1
;
VLOG
(
4
)
<<
"saved var_name: "
<<
var_name
<<
"is saved ready!"
;
}
VLOG
(
4
)
<<
"SaveInSwitchWithShard success"
;
return
0
;
}
// Streams back, via the response attachment, the bytes previously stored by
// SaveInSwitchWithShard for each requested variable of the given group, then
// marks each variable consumed (vars_ready_flag == 0). Always returns 0.
int SendAndRecvVariableHandler::QueryInSwitchWithShard(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  VLOG(4) << "entering QueryInSwitchWithShard";
  int32_t group_id = request->group_id();
  VLOG(4) << "group id: " << group_id;
  // NOTE(review): unlike SaveInSwitchWithShard, group_id is not range-checked
  // here before indexing _local_shards — confirm callers guarantee validity.
  auto& local_shard = _local_shards[group_id];
  auto& response_io_buffer = cntl->response_attachment();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto msg_name = request->message_name();
  response->set_message_name(msg_name);
  for (auto& req_var_name : req_var_names) {
    VLOG(4) << "req var name: " << req_var_name;
    response->add_send_var_names(req_var_name);
    // Block until SaveInSwitchWithShard has flagged this variable produced.
    WaitForVarsProduced(group_id, req_var_name);
    std::unique_lock<std::mutex> lk(scope_mutex_);
    // NOTE(review): the find() result is used without checking for a miss;
    // WaitForVarsProduced presumably guarantees presence — confirm.
    auto itr = local_shard.find(req_var_name);
    auto& value = itr.value();
    response_io_buffer.append(value.data(), value.size());
    value.resize(0);  // release the buffered bytes
    vars_ready_flag[group_id][req_var_name] = 0;
    VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!";
  }
  VLOG(4) << "heter server QueryInSwitchWithShard done";
  return 0;
}
// Scope-based variant of SaveInSwitchWithShard: deserializes the request's
// variables into local_scope_ptr (a framework::Scope) instead of a byte
// shard, then flags each variable produced under group 0. Always returns 0.
int SendAndRecvVariableHandler::SaveInSwitchWithScope(
    const MultiVarMsg* request,
    PsResponseMessage* response,
    brpc::Controller* cntl) {
  VLOG(4) << "entering SaveInSwitchWithScope";
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto message_name = request->message_name();
  VLOG(4) << "message_name in heter server: " << message_name;
  auto send_var_nums = request->send_var_names_size();
  std::vector<std::string> send_var_names(send_var_nums);
  // Variable names come from the per-variable messages, not send_var_names.
  for (int idx = 0; idx < send_var_nums; idx++) {
    send_var_names[idx] = request->var_messages(idx).varname();
  }
  std::unique_lock<std::mutex> lk(scope_mutex_);
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope";
  }
  for (auto var_name : send_var_names) {
    auto* var_exist_ptr = local_scope->FindVar(var_name);
    if (!var_exist_ptr) {
      VLOG(4) << "not find var: " << var_name << " in local_scope";
    }
    // NOTE(review): WaitForVarsConsumed locks scope_mutex_ internally while
    // `lk` above already holds it — on a non-recursive std::mutex this
    // self-deadlocks if ever reached with the flag set. Confirm whether this
    // scope-based path is actually exercised (the shard-based path is used
    // by SendS2S/RecvFromSwitch).
    WaitForVarsConsumed(0, var_name);
  }
  auto& request_io_buffer = cntl->request_attachment();
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      *request, &request_io_buffer, cpu_dev_ctx, local_scope);
  lk.unlock();
  // Mark every deserialized variable as produced under group 0.
  for (auto var_name : send_var_names) {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    vars_ready_flag[0][var_name] = 1;
  }
  VLOG(4) << "SaveInSwitchWithScope success";
  return 0;
}
// Scope-based variant of QueryInSwitchWithShard: serializes each requested
// variable out of local_scope_ptr into the response (message + attachment),
// then flags each variable consumed under group 0. Always returns 0.
int SendAndRecvVariableHandler::QueryInSwitchWithScope(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  VLOG(4) << "entering QueryInSwitchWithScope";
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    LOG(INFO) << "local_scope is null";
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  // get req message_name & req_var_names
  auto msg_name = request->message_name();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto& response_io_buffer = cntl->response_attachment();
  // 1. fill message_name(string)
  response->set_message_name(msg_name);
  // 2. fill var_names(string)
  for (auto& req_var_name : req_var_names) {
    response->add_send_var_names(req_var_name);
  }
  // 3. fill var_messages(VarMessage)
  for (auto& req_var_name : req_var_names) {
    // Block until the producer has flagged this variable under group 0.
    WaitForVarsProduced(0, req_var_name);
    auto* send_var_msg = response->add_var_messages();
    send_var_msg->set_varname(req_var_name);
    framework::Variable* var_ptr;
    var_ptr = local_scope->FindVar(req_var_name);
    // NOTE(review): a null var_ptr is only logged, then dereferenced by
    // IsType below — confirm the producer guarantees the variable exists.
    if (!var_ptr) {
      LOG(INFO) << "local_scope not find var: " << req_var_name;
    }
    butil::IOBuf temp_iobuf;
    if (var_ptr->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    } else if (var_ptr->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    }
    response_io_buffer.append(temp_iobuf);
  }
  // Mark every serialized variable as consumed under group 0.
  for (auto& req_var_name : req_var_names) {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    vars_ready_flag[0][req_var_name] = 0;
  }
  VLOG(4) << "heter server QueryInSwitchWithScope done";
  return 0;
}
}
// end namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/heter_server.h
0 → 100644
View file @
de2e6515
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
namespace
paddle
{
namespace
framework
{
class
Executor
;
class
ProgramDesc
;
class
Scope
;
}
// namespace framework
}
// namespace paddle
DECLARE_double
(
eager_delete_tensor_gb
);
namespace
paddle
{
namespace
distributed
{
DECLARE_int32
(
pserver_timeout_ms
);
DECLARE_int32
(
heter_world_size
);
DECLARE_int32
(
switch_send_recv_timeout_s
);
using
MultiVarMsg
=
MultiVariableMessage
;
using
VarMsg
=
VariableMessage
;
using
serviceHandler
=
std
::
function
<
int32_t
(
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
// NOLINT
brpc
::
Controller
*
cntl
)
>
;
using
HeterServiceHandler
=
std
::
function
<
int32_t
(
const
MultiVarMsg
*
,
MultiVarMsg
*
,
brpc
::
Controller
*
)
>
;
using
HeterRpcCallbackFunc
=
std
::
function
<
void
(
void
*
)
>
;
// Abstract base for request handlers used by the heter service. Holds
// non-owning pointers to the scope and device context the handler should
// operate on; concrete handlers implement Handle().
class ServiceHandlerBase {
 public:
  ServiceHandlerBase() = default;
  virtual ~ServiceHandlerBase() = default;

  // Borrows `scope`; the caller retains ownership.
  void SetScope(const framework::Scope* scope) { scope_ = scope; }

  // Borrows `dev_ctx`; the caller retains ownership.
  void SetDevCtx(const platform::DeviceContext* dev_ctx) {
    dev_ctx_ = dev_ctx;
  }

  // Processes one request/response pair; returns 0 on success.
  virtual int Handle(const MultiVarMsg* request,
                     MultiVarMsg* response,
                     brpc::Controller* cntl) = 0;

 protected:
  const platform::DeviceContext* dev_ctx_{nullptr};
  const framework::Scope* scope_{nullptr};
};
using
SharedMiniScope
=
std
::
shared_ptr
<
std
::
unordered_map
<
int
,
::
paddle
::
framework
::
Scope
*>>
;
using
SharedMicroScope
=
std
::
shared_ptr
<
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
std
::
vector
<::
paddle
::
framework
::
Scope
*>>>>
;
using
SharedTaskQueue
=
std
::
shared_ptr
<
std
::
unordered_map
<
int
,
std
::
shared_ptr
<::
paddle
::
framework
::
BlockingQueue
<
std
::
pair
<
std
::
string
,
int
>>>>>
;
// A resizable raw byte buffer used as the value type stored in the switch's
// sparse shards. Thin wrapper over std::vector<char>.
class ValueInSwitch {
 public:
  ValueInSwitch() = default;
  ~ValueInSwitch() = default;

  // Pointer to the first byte of the buffer (mutable).
  char* data() { return _data.data(); }

  // Current number of bytes held.
  size_t size() { return _data.size(); }

  // Grows or shrinks the buffer to exactly `size` bytes.
  void resize(size_t size) { _data.resize(size); }

  // Releases unused capacity back to the allocator.
  void shrink_to_fit() { _data.shrink_to_fit(); }

 private:
  std::vector<char> _data;
};
// Handler for SendAndRecvVariable traffic on a heter worker, and for the
// byte-shard store used when the process acts as a switch. Variables flow
// through either framework Scopes (per mini/micro batch) or _local_shards
// (raw bytes keyed by variable name, one shard per group).
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
 public:
  SendAndRecvVariableHandler() {
    this->num_microbatch_ = 0;
    this->num_minibatch_ = 0;
    // One byte-shard per heter group.
    _local_shards.reset(new shard_type[FLAGS_heter_world_size]);
  }

  virtual ~SendAndRecvVariableHandler() {}

  // Shares the mini-batch scope map with the trainer and caches its size.
  void SetMiniScopes(SharedMiniScope mini_scopes) {
    mini_scopes_ = mini_scopes;
    num_minibatch_ = mini_scopes_->size();
  }

  // Shares the micro-batch scope map; the micro-batch count is taken from
  // the first entry (all entries are assumed equally sized — TODO confirm).
  void SetMicroScopes(SharedMicroScope micro_scopes) {
    micro_scopes_ = micro_scopes;
    for (auto& scope_pair : (*micro_scopes_)) {
      // auto mini_idx = scope_pair.first;
      auto& micro_scopes = scope_pair.second;
      num_microbatch_ = micro_scopes->size();
      break;
    }
  }

  // Returns the number of task queues (despite the name, this is the
  // task_queue_ map size, read under scope_mutex_).
  int GetThreadNum() {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    return (*task_queue_).size();
  }

  int SaveInSwitchWithScope(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);

  // Busy-waits until vars_ready_flag[group_id][var_name] == 0, i.e. the
  // previously stored value has been consumed. The flag is checked under
  // scope_mutex_ but the loop itself spins without yielding.
  void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 0) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not consumed exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }

  // Busy-waits until vars_ready_flag[group_id][var_name] == 1, i.e. a
  // producer has stored a value for this variable.
  void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 1) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not produced exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }

  int SaveInSwitchWithShard(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);

  int QueryInSwitchWithShard(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);

  int QueryInSwitchWithScope(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);

  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }

  // Heter-worker entry point: deserializes the request into a per-micro-batch
  // scope (creating mini/micro scopes and the task queue on first sight of a
  // mini-batch), enqueues (message_name, microbatch) for the trainer, and
  // serializes the requested reply variables from a temporary local scope.
  int Handle(const MultiVarMsg* request,
             MultiVarMsg* response,
             brpc::Controller* cntl) override {
    LOG(INFO) << "entered Handle";
    platform::RecordEvent record_event(
        "SendAndRecvVariableHandler->Handle",
        platform::TracerEventType::Communication,
        1);
    // Disable eager tensor GC while handling the request.
    FLAGS_eager_delete_tensor_gb = -1;
    // get microID from request
    // deserialize variable to micro scope
    // Push to heter worker's task_queue
    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
        new paddle::framework::Scope());
    auto& local_scope = *(local_scope_ptr.get());
    platform::DeviceContextPool& pool =
        platform::DeviceContextPool::Instance();
    platform::CPUPlace cpu_place;
    auto& cpu_dev_ctx = *pool.Get(cpu_place);
    auto message_name = request->message_name();
    auto& request_io_buffer = cntl->request_attachment();
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);
    // The micro-batch id travels as a float tensor named "microbatch_id".
    auto* var = local_scope.FindVar("microbatch_id");
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable microbatch_id in scope."));
    auto* tensor = var->GetMutable<framework::LoDTensor>();
    auto data = reinterpret_cast<const float*>(tensor->data());
    auto micro_id = static_cast<int>(data[0]);
    VLOG(4) << "micro_id in heter server: " << micro_id;
    // micro_id encodes minibatch*10 + microbatch (at most 10 micro-batches
    // per mini-batch — TODO confirm this invariant at the sender).
    int minibatch_index = micro_id / 10;
    int microbatch_index = micro_id % 10;
    // check minibatch_index is in mini_scopes_
    std::unique_lock<std::mutex> lk(scope_mutex_);
    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
      lk.unlock();
      PADDLE_ENFORCE_EQ(
          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(),
          1,
          platform::errors::InvalidArgument(
              "minibatch index should in current trainer"));
    } else {
      // create mini scope & micro scopes
      auto* minibatch_scope = &(scope_->NewScope());
      (*mini_scopes_)[minibatch_index] = minibatch_scope;
      (*micro_scopes_)[minibatch_index].reset(
          new std::vector<paddle::framework::Scope*>{});
      for (int i = 0; i < num_microbatch_; i++) {
        auto* micro_scope = &(minibatch_scope->NewScope());
        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
      }
      (*task_queue_)[minibatch_index].reset(
          new ::paddle::framework::BlockingQueue<
              std::pair<std::string, int>>());
      lk.unlock();
    }
    auto* micro_scope =
        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
    // Second deserialize targets the persistent micro scope on dev_ctx_.
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, *dev_ctx_, micro_scope);
    // blocking queue handles multi thread
    VLOG(4) << "Handle in HeterServer: " << message_name << ", "
            << microbatch_index;
    VLOG(4) << "task_queue_ size: " << task_queue_->size();
    (*task_queue_)[minibatch_index]->Push(
        std::make_pair(message_name, microbatch_index));
    auto response_var_nums = request->recv_var_names_size();
    std::vector<std::string> response_var_names(response_var_nums),
        empty_var_names{};
    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
      response_var_names[var_idx] = request->recv_var_names(var_idx);
    }
    auto& response_io_buffer = cntl->response_attachment();
    // Reply variables are read from the temporary local scope, not the
    // micro scope.
    distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                                response_var_names,
                                                empty_var_names,
                                                *dev_ctx_,
                                                &local_scope,
                                                response,
                                                &response_io_buffer);
    VLOG(4) << "Handle over";
    return 0;
  }

 public:
  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
  // Scope used by the *WithScope switch paths (set externally).
  std::shared_ptr<paddle::framework::Scope> local_scope_ptr;  // for switch
  // group_id -> (var_name -> 0/1 produced flag); guarded by scope_mutex_.
  std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
      vars_ready_flag;
  // Byte-shard store, one shard per group (size FLAGS_heter_world_size).
  std::unique_ptr<shard_type[]> _local_shards;
  platform::Timer timeline_;

 private:
  // share with HeterPipelineTrainer
  SharedMiniScope mini_scopes_{nullptr};
  SharedMicroScope micro_scopes_{nullptr};
  int num_microbatch_;
  int num_minibatch_;
  std::mutex scope_mutex_;
  bool is_first_stage_ = false;
  bool is_last_stage_ = false;
  SharedTaskQueue task_queue_;
};
// brpc service implementation for heter parameter-server traffic. Dispatches
// generic ps commands (service), worker message handlers registered by name
// (SendAndRecvVariable), and the switch proxy paths (SendToSwitch/SendS2S/
// RecvFromSwitch/SendToWorker) backed by SendAndRecvVariableHandler shards.
class HeterService : public PsService {
 public:
  HeterService() {
    // Map ps command ids to member handlers for the generic service() entry.
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  virtual ~HeterService() {}

  // Generic command entry: looks up the handler by request->cmd_id() and
  // reports its return code through the response.
  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      // NOTE(review): err_msg is built but never attached to the response
      // (no set_err_msg/set_err_code here), so unknown cmd ids return
      // success to the caller — confirm whether that is intended.
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }

  // Dispatches a worker message to the handler registered for its
  // message_name; aborts (PADDLE_ENFORCE) when the name is unknown.
  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller,
      const MultiVarMsg* request,
      MultiVarMsg* response,
      ::google::protobuf::Closure* done) {
    // This object helps you to call done->Run() in RAII style. If you need
    // to process the request asynchronously, pass done_guard.release().
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr,
        handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
            "which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
    // We don't want to call done->Run() here, release the guard.
    // done_guard.release();
  }

  // Switch query path: returns stored variable bytes via the shard store.
  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request,
                              MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.QueryInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "QueryInSwitchWithScope failed!";
    }
    // response->set_message_name(message_name);
  }

  // Switch proxy path: forwards the request (and its attachment) to peer
  // switch 0 via a synchronous-looking SendS2S RPC, then copies the peer's
  // response attachment back to the original caller.
  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    std::shared_ptr<HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_->peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
    }
    brpc::Channel* channel =
        switch_client_ptr_->peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // proxy: build a fresh OnHeterRpcDone for the outbound call (or reset
    // one inside OnHeterRpcDone).
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(),
            true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(
        cntl->request_attachment().movable());
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    // brpc::Controller std_cntl;
    // std_cntl.request_attachment().append(cntl->request_attachment().movable());
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    // NOTE(review): the response attachment is appended BEFORE fut.wait()
    // completes the async RPC — confirm the attachment is guaranteed to be
    // populated at this point, otherwise this copies an empty buffer.
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    fut.wait();
    VLOG(4) << "SendToSwitch done";
    delete closure2;
  }

  // Switch-to-switch receive: stores the forwarded variables into the shard
  // store and reports the result in the response error fields.
  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request,
               PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.SaveInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // if (itr == handler_map_.end()) {
    //   LOG(ERROR) << "can not find func handler";
    //}
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "SaveInSwitchWithScope failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }

  // Forwards a request from the switch to peer worker 0 asynchronously.
  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request,
                    PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    // NOTE(review): `done` is both guarded by done_guard here and passed as
    // the completion closure to stub.SendAndRecvVariable below — confirm
    // OnHeterRpcDone tolerates being run from both paths.
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
    std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_->peer_worker_list_[0];
    brpc::Channel* channel =
        switch_client_ptr_->peer_worker_channels_[0].get();
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PsService_Stub stub(channel);
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
    // fill response content
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
  }

  // Registers a named handler used by SendAndRecvVariable dispatch.
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }

  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }

  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }

  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }

  // Marks the service as exited regardless of how many stop requests have
  // been received.
  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }

  bool IsExit() { return is_exit_; }

 private:
  int32_t stop_profiler(const PsRequestMessage& request,
                        PsResponseMessage& response,  // NOLINT
                        brpc::Controller* cntl) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }

  int32_t start_profiler(const PsRequestMessage& request,
                         PsResponseMessage& response,  // NOLINT
                         brpc::Controller* cntl) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }

  // Records a stop request per client; exits once every one of the fan_in_
  // upstream clients has asked to stop.
  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response,  // NOLINT
                            brpc::Controller* cntl) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;
  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};
// Process-wide singleton that owns the brpc servers (main + inter) and the
// HeterService, and coordinates startup/shutdown between the serving thread
// (StartHeterService blocks) and controller threads (WaitServerReady/Stop).
class HeterServer {
 public:
  HeterServer() : ready_(0) {}
  virtual ~HeterServer() {}

  // Signals the serving thread to exit and shuts down the main brpc server.
  // Idempotent: returns immediately if already stopped.
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_ == true) return;
    if (!IsExit()) {
      service_.ForceExit();
    }
    stoped_ = true;
    cv_.notify_all();
    server_.Stop(1000);
    server_.Join();
  }

  bool IsStop() {
    std::unique_lock<std::mutex> lock(mutex_);
    return stoped_;
  }

  bool IsExit() { return service_.IsExit(); }

  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func);

  // Both Start* calls block the caller; run them on dedicated threads.
  void StartHeterService(bool need_encrypt = false);

  void StartHeterInterService(bool need_encrypt = false);

  void SetEndPoint(const std::string& endpoint) {
    this->endpoint_ = endpoint;
    service_.SetEndpoint(endpoint);
  }

  void SetLocalScope() {
    request_handler_->local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  void SetInterEndpoint(const std::string& endpoint) {
    this->endpoint_inter_ = endpoint;
    service_.SetInterEndpoint(endpoint);
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    this->peer_endpoints_ = peer_endpoints;
    service_.SetPeerEndPoints(peer_endpoints);
  }

  void SetFanin(const int& fan_in);

  void SetServiceHandler(
      std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
    request_handler_ = request_handler;
  }

  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
    request_handler_->SetMiniScopes(mini_scopes);
  }

  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
    request_handler_->SetMicroScopes(micro_scopes);
  }

  int GetThreadNum() { return request_handler_->GetThreadNum(); }

  void SetTaskQueue(SharedTaskQueue task_queue) {
    request_handler_->SetTaskQueue(task_queue);
  }

  // HeterWrapper singleton
  static std::shared_ptr<HeterServer> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServer());
    }
    return s_instance_;
  }

  // Blocks until a Start* call has published ready_ == 1 (see .cc).
  void WaitServerReady();

 private:
  static std::shared_ptr<HeterServer> s_instance_;
  mutable std::mutex mutex_;
  static std::mutex mtx_;
  std::condition_variable cv_;
  std::condition_variable condition_ready_;
  bool stoped_ = true;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

 protected:
  brpc::Server server_;
  brpc::Server server_inter_;
  HeterService service_;
  std::shared_ptr<SendAndRecvVariableHandler> request_handler_;
  DISABLE_COPY_AND_ASSIGN(HeterServer);
  std::mutex mutex_ready_;
  int ready_;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/service/ps_client.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {

// Register every concrete PSClient implementation under its class name so
// that PSClientFactory::Create can instantiate one from the config string.
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);
int32_t
PSClient
::
Configure
(
// called in FleetWrapper::InitWorker
const
PSParameter
&
config
,
const
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
&
regions
,
PSEnvironment
&
env
,
size_t
client_id
)
{
_env
=
&
env
;
_config
=
config
;
_dense_pull_regions
=
regions
;
_client_id
=
client_id
;
_config
.
mutable_worker_param
()
->
mutable_downpour_worker_param
()
->
mutable_downpour_table_param
()
->
CopyFrom
(
_config
.
server_param
()
.
downpour_server_param
()
.
downpour_table_param
());
const
auto
&
work_param
=
_config
.
worker_param
().
downpour_worker_param
();
for
(
int
i
=
0
;
i
<
work_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
accessor
=
CREATE_PSCORE_CLASS
(
ValueAccessor
,
work_param
.
downpour_table_param
(
i
).
accessor
().
accessor_class
());
accessor
->
Configure
(
work_param
.
downpour_table_param
(
i
).
accessor
());
accessor
->
Initialize
();
_table_accessors
[
work_param
.
downpour_table_param
(
i
).
table_id
()].
reset
(
accessor
);
}
return
Initialize
();
}
// Builds a PSClient of the class named in
// ServerParameter.downpour_server_param.service_param.client_class.
// Returns nullptr (after logging) when the config is incomplete or the
// class was never registered; the caller owns the returned object.
PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();
  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }
  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }
  if (!config.downpour_server_param().service_param().has_client_class()) {
    LOG(ERROR) << "miss client_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }
  const auto &service_param = config.downpour_server_param().service_param();
  PSClient *client =
      CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
  if (client == nullptr) {
    LOG(ERROR) << "client is not registered, server_name:"
               << service_param.client_class();
    return nullptr;
  }
  TableManager::Instance().Initialize();
  VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
  return client;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_client.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
public:
explicit
PSClientClosure
(
PSClientCallBack
callback
)
:
_callback
(
callback
)
{}
virtual
~
PSClientClosure
()
{}
virtual
void
set_promise_value
(
int
value
)
{
for
(
auto
&
promise
:
_promises
)
{
promise
->
set_value
(
value
);
}
}
void
add_promise
(
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>
&
promise
)
{
// NOLINT
_promises
.
push_back
(
promise
);
}
void
add_timer
(
std
::
shared_ptr
<
CostTimer
>
&
timer
)
{
// NOLINT
_timers
.
push_back
(
timer
);
}
protected:
PSClientCallBack
_callback
;
std
::
vector
<
std
::
shared_ptr
<
CostTimer
>>
_timers
;
std
::
vector
<
std
::
shared_ptr
<
std
::
promise
<
int32_t
>>>
_promises
;
};
// Abstract parameter-server client.  Concrete transports (brpc, local,
// graph, coordinator) subclass this; instances are created by class name
// through PSClientFactory::Create.  Optional operations default to a
// logged, already-failed future (see FailedNotImplemented).
class PSClient {
 public:
  PSClient() {}
  virtual ~PSClient() {}
  PSClient(PSClient &&) = delete;
  PSClient(const PSClient &) = delete;

  // Parses the worker-side config, builds one ValueAccessor per table and
  // delegates to the subclass Initialize().  Defined in ps_client.cc.
  virtual int32_t Configure(  // called in FleetWrapper::InitWorker
      const PSParameter &config,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>
          &regions,
      PSEnvironment &_env,  // NOLINT
      size_t client_id) final;

  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry) = 0;

  // Trigger eviction ("shrink") of table entries below `threshold`.
  virtual std::future<int32_t> Shrink(uint32_t table_id,
                                      const std::string threshold) = 0;

  // Load every table from checkpoint `epoch` with the given mode.
  virtual std::future<int32_t> Load(const std::string &epoch,
                                    const std::string &mode) = 0;
  // Load a single table.
  virtual std::future<int32_t> Load(uint32_t table_id,
                                    const std::string &epoch,
                                    const std::string &mode) = 0;

  // Save every table; the value accessor may apply mode-dependent
  // save conditions.
  virtual std::future<int32_t> Save(const std::string &epoch,
                                    const std::string &mode) = 0;
  // Save a single table.
  virtual std::future<int32_t> Save(uint32_t table_id,
                                    const std::string &epoch,
                                    const std::string &mode) = 0;

  // Clear table data (all tables / one table).
  virtual std::future<int32_t> Clear() = 0;
  virtual std::future<int32_t> Clear(uint32_t table_id) = 0;

  // Pull the dense parameter partition and scatter it into the caller's
  // regions.  The buffers must stay valid until the future resolves.
  virtual std::future<int32_t> PullDense(Region *regions,
                                         size_t region_num,
                                         size_t table_id) = 0;

  // Push the initial dense parameters to the server: on cold start dense
  // weights are initialized on the trainer.
  virtual std::future<int32_t> PushDenseParam(const Region *regions,
                                              size_t region_num,
                                              size_t table_id) = 0;

  virtual std::future<int32_t> PushDense(const Region *regions,
                                         size_t region_num,
                                         size_t table_id) = 0;

  // Pull sparse values for `num` keys into select_values.  is_training
  // distinguishes train/predict handling on the server side.
  virtual std::future<int32_t> PullSparse(float **select_values,
                                          size_t table_id,
                                          const uint64_t *keys,
                                          size_t num,
                                          bool is_training) = 0;

  virtual std::future<int32_t> PullSparseParam(float **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num,
                                               bool is_training) {
    return FailedNotImplemented();
  }

  virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num) {
    return FailedNotImplemented();
  }

  virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;

  // Force every batched pending request to be sent.
  virtual std::future<int32_t> Flush() = 0;

  // Graceful server shutdown.
  virtual std::future<int32_t> StopServer() = 0;

  // Server-side profiler control.
  virtual std::future<int32_t> StartProfiler() = 0;
  virtual std::future<int32_t> StopProfiler() = 0;

  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type) = 0;

  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float> *values,
                                            std::vector<uint64_t> *keys,
                                            int pserver_idx) = 0;

  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t *total_send_data,
                                              void *done) = 0;

  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string &path) = 0;

  virtual void FinalizeWorker() = 0;

  // client-to-client message send
  virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                                    int to_client_id,
                                                    const std::string &msg) {
    return FailedNotImplemented();
  }

  // client2client message handling:
  // std::function<int32_t(int, int, const std::string&)>
  //   -> ret (msg_type, from_client_id, msg)
  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;

  virtual int RegisteClient2ClientMsgHandler(int msg_type,
                                             MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }

  // Dispatches an incoming client2client message to its registered handler;
  // returns -1 for unknown message types.
  virtual int HandleClient2ClientMsg(int msg_type,
                                     int from_client_id,
                                     const std::string &msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
      return -1;
    }
    return itr->second(msg_type, from_client_id, msg);
  }

  // Returns the accessor configured for `table_id`, or nullptr if unknown.
  virtual ValueAccessor *GetTableAccessor(size_t table_id) {
    auto itr = _table_accessors.find(table_id);
    if (itr == _table_accessors.end()) {
      return nullptr;
    }
    return itr->second.get();
  }

  virtual size_t GetServerNums() = 0;

  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float *total_send_data,
                                                    size_t total_send_data_size,
                                                    void *done) = 0;

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t *keys,
      const float **update_values,
      size_t num,
      void *done) = 0;

  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t *keys,
      const float **update_values,
      uint32_t num,
      void *done,
      int pserver_idx) = 0;

  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t *keys,
                                               const float **update_values,
                                               size_t num,
                                               void *done) = 0;

  virtual std::future<int32_t> PushSparse(size_t table_id,
                                          const uint64_t *keys,
                                          const float **update_values,
                                          size_t num) = 0;

  // for save cache
  virtual std::future<int32_t> CacheShuffle(
      uint32_t table_id,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) {
    return FailedNotImplemented();
  }

  virtual std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) {
    return FailedNotImplemented();
  }

  virtual std::future<int32_t> SaveCache(uint32_t table_id,
                                         const std::string &path,
                                         const std::string &mode) {
    return FailedNotImplemented();
  }

  virtual std::future<int32_t> GetCacheThreshold(
      uint32_t table_id, double &cache_threshold) {  // NOLINT
    return FailedNotImplemented();
  }

  virtual std::future<int32_t> Revert() { return FailedNotImplemented(); }

  virtual std::future<int32_t> CheckSavePrePatchDone() {
    return FailedNotImplemented();
  }

 protected:
  virtual int32_t Initialize() = 0;

  // Shared fallback for optional operations: logs "Did not implement" and
  // returns a future already resolved to -1.
  std::future<int32_t> FailedNotImplemented() {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int32_t> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  PSParameter _config;
  std::map<uint64_t, std::vector<paddle::distributed::Region>>
      _dense_pull_regions;
  std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>>
      _table_accessors;
  // handlers for client2client messages
  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;

 public:
  size_t _client_id;
  PSEnvironment *_env;
};
// A queued asynchronous request: payload, target table, an optional cost
// timer, and a shared promise the worker fulfils when the request is done.
template <class T>
class AsyncRequestTask {
 public:
  AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
  // Steals `data` (moves from the caller's object).
  AsyncRequestTask(T &data, size_t table_id, std::shared_ptr<CostTimer> &timer)
      : _table_id(table_id),
        _timer(timer),
        _promise(std::make_shared<std::promise<int32_t>>()) {
    _data = std::move(data);
  }
  // NOTE: this "copy" constructor intentionally moves the payload out of its
  // source (auto_ptr-style transfer), hence the NOLINT; the promise is
  // shared, so both objects resolve the same future.
  AsyncRequestTask(AsyncRequestTask &data)  // NOLINT
      : _table_id(data.table_id()),
        _timer(data.timer()),
        _promise(data.promise()) {
    _data = std::move(data.data());
  }
  ~AsyncRequestTask() {}

  inline T &data() { return _data; }
  inline size_t table_id() { return _table_id; }
  inline std::shared_ptr<CostTimer> &timer() { return _timer; }
  // Future tied to the shared promise; resolved by the request processor.
  inline std::future<int32_t> get_future() { return _promise->get_future(); }
  inline std::shared_ptr<std::promise<int32_t>> &promise() { return _promise; }

 private:
  T _data;
  size_t _table_id;
  std::shared_ptr<CostTimer> _timer;
  std::shared_ptr<std::promise<int32_t>> _promise;
};
// Declares the registry that maps class-name strings to PSClient factories.
REGISTER_PSCORE_REGISTERER(PSClient);

// Creates a PSClient subclass by the class name given in `config`;
// see ps_client.cc for the implementation.
class PSClientFactory {
 public:
  static PSClient *Create(const PSParameter &config);
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_local_client.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
//#define pslib_debug_dense_compress
namespace
paddle
{
namespace
distributed
{
int32_t
PsLocalClient
::
Initialize
()
{
const
auto
&
downpour_param
=
_config
.
server_param
().
downpour_server_param
();
TableManager
::
Instance
().
Initialize
();
for
(
int
i
=
0
;
i
<
downpour_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
table
=
CREATE_PSCORE_CLASS
(
Table
,
downpour_param
.
downpour_table_param
(
i
).
table_class
());
table
->
SetShard
(
0
,
1
);
table
->
Initialize
(
downpour_param
.
downpour_table_param
(
i
),
_config
.
fs_client_param
());
_table_map
[
downpour_param
.
downpour_table_param
(
i
).
table_id
()].
reset
(
table
);
}
return
0
;
}
// Shrink is currently a no-op for the local client.
// TODO: implement local entry eviction.
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  return done();
}

// Loads every registered table from checkpoint `epoch` with `mode`.
// Per-table results are synchronous (each Load completes before returning),
// so the aggregated future is already resolved.
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
                                           const std::string& mode) {
  for (auto& it : _table_map) {
    Load(it.first, epoch, mode);
  }
  return done();
}

// Loads a single table synchronously.
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
  return done();
}

// Saves every registered table to checkpoint `epoch` with `mode`.
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  for (auto& it : _table_map) {
    Save(it.first, epoch, mode);
  }
  return done();
}

// Flushes then saves a single table synchronously.
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
  return done();
}

// Clear is currently a no-op for the local client.  TODO: implement.
::std::future<int32_t> PsLocalClient::Clear() {
  return done();
}

// Per-table Clear is currently a no-op for the local client.  TODO: implement.
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
  return done();
}

// Nothing is batched locally, so Flush has nothing to do.
::std::future<int32_t> PsLocalClient::Flush() {
  return done();
}

// There is no remote server to stop in local mode.
::std::future<int32_t> PsLocalClient::StopServer() {
  return done();
}
// Pulls the whole dense shard from the local table into a temporary buffer,
// then scatters it byte-wise across the caller's regions in order.
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
  // Local mode has a single shard, so one shard holds the full dim.
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> region_buffer;
  region_buffer.resize(num_per_shard);
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  table_ptr->Pull(table_context);
  // table_ptr->PullDense(region_buffer.data(), region_buffer.size());

  // Scatter loop state: all sizes below are in BYTES (Region::size is a
  // byte count), while region_buffer is indexed through a uint8_t* view.
  size_t region_idx = 0;       // current region being filled
  size_t region_data_idx = 0;  // byte offset inside the current region
  size_t shard_data_size = num_per_shard;
  size_t shard_buffer_remain = shard_data_size * sizeof(float);
  PADDLE_ENFORCE_EQ(
      shard_buffer_remain,
      region_buffer.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));
  size_t index = 0;  // byte offset into region_buffer
  while (shard_buffer_remain > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    if (region.size - region_data_idx >= shard_buffer_remain) {
      // Current region can absorb everything that is left.
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             shard_buffer_remain);
      region_data_idx += shard_buffer_remain;
      shard_buffer_remain = 0;
    } else if (region.size - region_data_idx == 0) {
      // Current region is full; advance to the next one.
      ++region_idx;
      region_data_idx = 0;
    } else {
      // Fill the remainder of the current region and move on.
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             region.size - region_data_idx);
      shard_buffer_remain -= (region.size - region_data_idx);
      index += (region.size - region_data_idx);
      ++region_idx;
      region_data_idx = 0;
    }
  }
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushDenseParam
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
{
auto
*
accessor
=
GetTableAccessor
(
table_id
);
auto
*
table_ptr
=
GetTable
(
table_id
);
std
::
vector
<
float
>
region_buffer
;
region_buffer
.
resize
(
DenseDimPerShard
(
accessor
->
GetAccessorInfo
().
fea_dim
,
1
),
0
);
for
(
size_t
i
=
0
,
offset
=
0
;
i
<
region_num
;
++
i
)
{
uint32_t
data_num
=
regions
[
i
].
size
/
sizeof
(
float
);
memcpy
(
region_buffer
.
data
()
+
offset
,
regions
[
i
].
data
,
regions
[
i
].
size
);
offset
+=
data_num
;
}
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
push_context
.
values
=
region_buffer
.
data
();
table_context
.
push_context
.
is_param
=
true
;
table_context
.
num
=
region_buffer
.
size
();
table_ptr
->
Push
(
table_context
);
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return
done
();
}
// Applies a raw dense gradient directly to the local table and deletes the
// RPC closure (it is never Run() in local mode).
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  // table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

  // The closure is deleted without being Run(); callers must not rely on
  // the callback firing in local mode.
  delete closure;
  return done();
}
// Gathers the caller's gradient regions into one contiguous float buffer
// (with bounds checking) and pushes it to the local table.
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
  std::vector<float> region_buffer;
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    PADDLE_ENFORCE_LE(
        offset + data_num,
        data_size,
        platform::errors::PreconditionNotMet(
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]",
            offset,
            data_num,
            data_size));
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  // table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

  return done();
}
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
// // FIXME
// // auto timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// //将key拆分到各shard请求,并记录原始对应value指针
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
// for (int i = 0; i < num; ++i) {
// memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
// offset += value_size;
// }
//
// // return fut;
// return done();
//}
// Pulls value POINTERS for `num` keys straight out of the local table; the
// addresses are written into select_values (use_ptr = true).
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
                                                    size_t table_id,
                                                    const uint64_t* keys,
                                                    size_t num) {
  // FIXME: cost timers disabled for now.
  // auto timer =
  //     std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
  // auto local_timer =
  //     std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
  // Remote clients split keys per shard and remember the original value
  // pointers; locally the table handles everything in one call.
  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.keys = keys;
  table_context.pull_context.ptr_values = select_values;
  table_context.use_ptr = true;
  table_context.num = num;

  // table_ptr->PullSparsePtr(select_values, keys, num);
  table_ptr->Pull(table_context);

  return done();
}
// Applies a raw sparse gradient to the local table and deletes the RPC
// closure (it is never Run() in local mode).
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
    size_t table_id,
    const uint64_t* keys,
    const float** update_values,
    size_t num,
    void* callback) {
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
  // The closure is deleted without being Run(); callers must not rely on
  // the callback firing in local mode.
  delete closure;
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushSparse
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
)
{
auto
*
table_ptr
=
GetTable
(
table_id
);
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
push_context
.
keys
=
keys
;
table_context
.
push_context
.
ptr_values
=
update_values
;
table_context
.
num
=
num
;
table_context
.
use_ptr
=
true
;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr
->
Push
(
table_context
);
return
done
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_local_client.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
namespace
paddle
{
namespace
distributed
{
class
Table
;
// In-process implementation of PSClient: every "RPC" operates directly on
// local Table objects and every returned future is already completed.
// The repeated promise/future boilerplate of the original stubs is folded
// into the existing private done() helper (identical behavior: value 0).
class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }

  // No remote peers exist in local mode; connection setup is a no-op.
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
    return 0;
  }

  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
                                        const std::string threshold) override;
  virtual ::std::future<int32_t> Load(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
  virtual ::std::future<int32_t> StopServer() override;
  virtual void FinalizeWorker() override {}

  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDense(const Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);

  // Not supported locally: reports success immediately WITHOUT filling
  // select_values (same behavior as before).
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
                                            const uint64_t* keys,
                                            size_t num,
                                            bool is_training) {
    return done();
  }

  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);

  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
    return done();
  }

  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
  virtual ::std::future<int32_t> Flush();

  // Server-side profiling does not apply to the local client.
  virtual std::future<int32_t> StartProfiler() { return done(); }
  virtual std::future<int32_t> StopProfiler() { return done(); }

  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
    return done();
  }

  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
    return done();
  }

  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
    // The `done` callback parameter shadows the helper, hence `this->`.
    return this->done();
  }

  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
    return 0;
  }

  virtual ::std::future<int32_t> SendClient2ClientMsg(
      int msg_type, int to_client_id, const std::string& msg) override {
    return done();
  }

  // Local mode always runs against exactly one (implicit) server.
  virtual size_t GetServerNums() { return 1; }

  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback) override;

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback) override;

  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) override {
    return this->done();
  }

  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
    return this->done();
  }

 private:
  virtual int32_t Initialize() override;

  // Builds a future that is already fulfilled with 0 ("success").
  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

  // Dense dims owned by one shard (ceiling-ish split; +1 keeps the last
  // shard covered).
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
    return &_table_map;
  }

  // Returns the table for `table_id`, or nullptr (after logging) if absent.
  inline Table* GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return nullptr;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_local_server.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/ps/service/server.h"
namespace
paddle
{
namespace
distributed
{
// No-op PSServer used in local (single-process) mode: there is no real
// server to start, stop, or configure, so every operation reports success.
class PsLocalServer : public PSServer {
 public:
  PsLocalServer() {}
  virtual ~PsLocalServer() {}
  // Nothing to launch locally; always "succeeds".
  virtual uint64_t Start() { return 0; }
  virtual uint64_t Start(const std::string& ip, uint32_t port) { return 0; }
  virtual int32_t Stop() { return 0; }
  virtual int32_t Configure(
      const PSParameter& config,
      PSEnvironment& env,
      size_t server_rank,
      const std::vector<framework::ProgramDesc>& server_sub_program = {}) {
    return 0;
  }

 private:
  virtual int32_t Initialize() { return 0; }
};
}
// namespace distributed
}
// namespace paddle
Prev
1
…
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment