Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
d2d32668
Commit
d2d32668
authored
Apr 26, 2023
by
yuguo960516yuguo
Browse files
2.3.0-dtk-22.04.2
parent
ad08b8ce
Pipeline
#226
failed with stages
in 0 seconds
Changes
268
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5794 additions
and
0 deletions
+5794
-0
paddle/fluid/distributed/ps/service/communicator/communicator_common.h
...distributed/ps/service/communicator/communicator_common.h
+114
-0
paddle/fluid/distributed/ps/service/env.cc
paddle/fluid/distributed/ps/service/env.cc
+19
-0
paddle/fluid/distributed/ps/service/env.h
paddle/fluid/distributed/ps/service/env.h
+291
-0
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
+702
-0
paddle/fluid/distributed/ps/service/graph_brpc_client.h
paddle/fluid/distributed/ps/service/graph_brpc_client.h
+139
-0
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
+719
-0
paddle/fluid/distributed/ps/service/graph_brpc_server.h
paddle/fluid/distributed/ps/service/graph_brpc_server.h
+170
-0
paddle/fluid/distributed/ps/service/heter_client.cc
paddle/fluid/distributed/ps/service/heter_client.cc
+429
-0
paddle/fluid/distributed/ps/service/heter_client.h
paddle/fluid/distributed/ps/service/heter_client.h
+255
-0
paddle/fluid/distributed/ps/service/heter_server.cc
paddle/fluid/distributed/ps/service/heter_server.cc
+272
-0
paddle/fluid/distributed/ps/service/heter_server.h
paddle/fluid/distributed/ps/service/heter_server.h
+685
-0
paddle/fluid/distributed/ps/service/ps_client.cc
paddle/fluid/distributed/ps/service/ps_client.cc
+91
-0
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+357
-0
paddle/fluid/distributed/ps/service/ps_local_client.cc
paddle/fluid/distributed/ps/service/ps_local_client.cc
+328
-0
paddle/fluid/distributed/ps/service/ps_local_client.h
paddle/fluid/distributed/ps/service/ps_local_client.h
+237
-0
paddle/fluid/distributed/ps/service/ps_local_server.h
paddle/fluid/distributed/ps/service/ps_local_server.h
+44
-0
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc
...uid/distributed/ps/service/ps_service/graph_py_service.cc
+511
-0
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
...luid/distributed/ps/service/ps_service/graph_py_service.h
+215
-0
paddle/fluid/distributed/ps/service/ps_service/service.cc
paddle/fluid/distributed/ps/service/ps_service/service.cc
+137
-0
paddle/fluid/distributed/ps/service/ps_service/service.h
paddle/fluid/distributed/ps/service/ps_service/service.h
+79
-0
No files found.
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
paddle/fluid/distributed/ps/service/communicator/communicator_common.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
namespace paddle {
namespace distributed {

// Description of one communication context for a variable in the
// parameter-server runtime: which variable is sent, how it is sliced
// across endpoints, and per-table routing flags.
struct CommContext {
  CommContext() = default;

  // @param name          original variable name
  // @param names         names of the per-endpoint slices of the variable
  // @param emap          endpoint ("ip:port") for each slice, parallel to names
  // @param sections      row-count of each slice, parallel to names
  // @param origin_names  source variable names this context aggregates
  // @param id            trainer id issuing the communication
  // Remaining flags default to a dense, non-distributed table (-1 ids).
  CommContext(const std::string &name,
              const std::vector<std::string> &names,
              const std::vector<std::string> &emap,
              const std::vector<int64_t> &sections,
              const std::vector<std::string> &origin_names,
              int id,
              bool merge_add_ = true,
              bool is_sparse_ = true,
              bool is_distributed_ = false,
              int table_id_ = -1,
              bool is_tensor_table_ = false,
              bool is_datanorm_table_ = false,
              int64_t program_id_ = -1)
      : var_name(name),
        splited_varnames(names),
        epmap(emap),
        height_sections(sections),
        origin_varnames(origin_names),
        trainer_id(id),
        merge_add(merge_add_),
        is_sparse(is_sparse_),
        is_distributed(is_distributed_),
        table_id(table_id_),
        program_id(program_id_),
        is_tensor_table(is_tensor_table_),
        is_datanorm_table(is_datanorm_table_) {}

  // All members are trivially copyable containers/scalars, so the
  // compiler-generated member-wise copy is exactly what the previous
  // hand-written copy constructor did (rule of zero).
  CommContext(const CommContext &ctx) = default;

  // Human-readable dump of the whole context, used for logging/debugging.
  std::string print() const {
    std::stringstream ss;
    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
    ss << " table_id: " << table_id;

    for (size_t i = 0; i < splited_varnames.size(); i++) {
      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
         << " section: " << height_sections[i] << " ";
    }

    ss << "origin varnames: ";
    for (size_t i = 0; i < origin_varnames.size(); i++) {
      ss << origin_varnames[i] << " ";
    }

    ss << " aggregation->add: " << merge_add;
    ss << " is_sparse: " << is_sparse;
    ss << " is_distributed: " << is_distributed << "\n";
    ss << " table_id: " << table_id << "\n";
    ss << " program_id: " << program_id << "\n";
    ss << " is_tensor_table: " << is_tensor_table << "\n";
    ss << " is_datanorm_table: " << is_datanorm_table << "\n";

    return ss.str();
  }

  std::string var_name;                       // original variable name
  std::vector<std::string> splited_varnames;  // per-endpoint slice names
  std::vector<std::string> epmap;             // endpoint of each slice
  std::vector<int64_t> height_sections;       // row count of each slice
  std::vector<std::string> origin_varnames;   // aggregated source variables
  int trainer_id;
  bool merge_add;        // aggregate gradients with add (vs. average)
  bool is_sparse;
  bool is_distributed;
  int table_id;          // PS table id; -1 when not bound to a table
  int64_t program_id;    // owning program id; -1 when unset
  bool is_tensor_table;
  bool is_datanorm_table;
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/env.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Translation unit for env.h. The namespace body is intentionally empty:
// everything in PSEnvironment/PaddlePSEnvironment is implemented inline in
// the header; this .cc only anchors the header into the build.
#include "paddle/fluid/distributed/ps/service/env.h"

namespace paddle {
namespace distributed {}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/env.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <arpa/inet.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "gflags/gflags.h"
namespace
paddle
{
namespace
distributed
{
// One parameter-server endpoint: IPv4 address, port and rank, with
// conversions to/from a packed uint64 "sign" and an "ip:port:rank" string.
struct PSHost {
  std::string ip;
  uint32_t port;
  uint32_t rank;

  PSHost() = default;
  PSHost(const std::string &ip, uint32_t port, uint32_t rank)
      : ip(ip), port(port), rank(rank) {}

  // Packed layout of one host in a single uint64_t:
  // |---ip---|---port---|--rank--|
  // |-32bit--|--20bit---|--12bit-|
  // NOTE: port must fit in 20 bits and rank in 12 bits, otherwise the
  // fields overlap — callers are expected to respect these limits.
  uint64_t SerializeToUint64() {
    uint64_t host_label = 0;
    host_label = inet_addr(ip.c_str());  // IPv4 in network byte order
    host_label = host_label << 32;
    host_label += (port << 12);
    host_label += rank;
    return host_label;
  }

  // Inverse of SerializeToUint64.
  void ParseFromUint64(uint64_t host_label) {
    constexpr uint64_t rank_label_mask = (1L << 12) - 1;
    constexpr uint64_t port_label_mask = (1L << 20) - 1;
    rank = host_label & rank_label_mask;
    port = (host_label >> 12) & port_label_mask;
    uint32_t ip_addr = (host_label >> 32);
    // Rebuild the dotted-quad string via a proper in_addr instead of the
    // previous type-punning cast of &ip_addr.
    in_addr addr;
    addr.s_addr = ip_addr;
    ip = inet_ntoa(addr);
  }

  // Debug dump: "host: <ip> port: <p> rank: <r> uint: <sign>".
  std::string ToString() {
    std::stringstream s;
    s << "host: " << ip;
    s << " port: " << port;
    s << " rank: " << rank;
    s << " uint: " << SerializeToUint64();
    return s.str();
  }

  // for open source parameter server: "ip:port:rank"
  std::string SerializeToString() {
    std::stringstream s;
    s << ip << ":";
    s << port << ":";
    s << rank;
    return s.str();
  }

  // Parses "ip:port:rank". Assumes a well-formed endpoint with exactly
  // three fields; malformed input will throw from std::stoi or index
  // out of range, matching the original behavior.
  void ParseFromString(std::string endpoint) {
    std::vector<std::string> endpoint_info;
    StringSplit(endpoint, ':', &endpoint_info);
    ip = endpoint_info[0];
    port = std::stoi(endpoint_info[1]);
    rank = std::stoi(endpoint_info[2]);
  }

  // Splits str on sep into *pieces. Empty interior fields are kept; a
  // trailing empty field is dropped. An entirely empty str yields no
  // pieces unless ignore_null is false.
  void StringSplit(const std::string &str,
                   char sep,
                   std::vector<std::string> *pieces,
                   bool ignore_null = true) {
    pieces->clear();
    if (str.empty()) {
      if (!ignore_null) {
        pieces->push_back(str);
      }
      return;
    }
    size_t pos = 0;
    size_t next = str.find(sep, pos);
    while (next != std::string::npos) {
      pieces->push_back(str.substr(pos, next - pos));
      pos = next + 1;
      next = str.find(sep, pos);
    }
    if (!str.substr(pos).empty()) {
      pieces->push_back(str.substr(pos));
    }
  }
};
// Registry of parameter-server and worker (client) endpoints.
// This base implementation only stores hosts registered one at a time via
// RegistePsServer/RegistePsClient; the bulk Set* hooks are no-ops that
// subclasses (e.g. PaddlePSEnvironment) override.
class PSEnvironment {
 public:
  PSEnvironment() = default;
  virtual ~PSEnvironment() = default;

  // Bulk-install server hosts from packed uint64 signs. No-op here.
  virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
    return 0;
  }
  // Bulk-install server hosts from "ip:port:rank" strings. No-op here.
  virtual int32_t SetPsServers(
      const std::vector<std::string> *host_endpoint_list, int node_num) {
    return 0;
  }
  // Bulk-install client hosts from packed uint64 signs. No-op here.
  virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
    return 0;
  }
  // Bulk-install client hosts from endpoint strings. No-op here.
  virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
    return 0;
  }
  // Packed sign of the local host; 0 in the base class.
  virtual uint64_t GetLocalHostSign() { return 0; }

  virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
  virtual int32_t RegistePsServer(const std::string &ip,
                                  uint32_t port,
                                  int32_t rank) {
    return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
  }

  virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
  virtual int32_t RegistePsClient(const std::string &ip,
                                  uint32_t port,
                                  int32_t rank) {
    return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
  }

  // Packed uint64 signs of every registered client, in registration order.
  virtual std::vector<uint64_t> GetClientInfo() {
    std::vector<uint64_t> client_info;
    for (auto &i : _ps_client_list) {
      client_info.push_back(i.SerializeToUint64());
    }
    return client_info;
  }

  // "ip:port:rank" strings of every registered client; returns an empty
  // vector when use_string_endpoint is false.
  virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
    if (use_string_endpoint) {
      std::vector<std::string> client_info;
      for (auto &i : _ps_client_list) {
        client_info.push_back(i.SerializeToString());
      }
      return client_info;
    }
    return {};
  }

  virtual void SetTrainers(int trainers) { trainers_ = trainers; }
  virtual int GetTrainers() { return trainers_; }

 protected:
  // Register a single host into host_list, deduplicated via sign_set.
  // NOTE(review): deduplication keys on rank only, not on the full
  // (ip, port, rank) sign — a second host with an already-seen rank is
  // silently dropped. Preserved as-is; always returns 0.
  virtual int32_t RegistePsHost(const std::string &ip,
                                uint32_t port,
                                int32_t rank,
                                std::vector<PSHost> &host_list,  // NOLINT
                                std::unordered_set<uint64_t> &sign_set) {
    // NOLINT
    PSHost host;
    host.ip = ip;
    host.port = port;
    host.rank = rank;
    if (sign_set.count(rank) == 0) {
      host_list.push_back(host);
      sign_set.insert(rank);
    }
    return 0;
  }

  int trainers_ = 0;
  std::vector<PSHost> _ps_client_list;
  std::unordered_set<uint64_t> _ps_client_sign_set;  // for unique filter
  std::vector<PSHost> _ps_server_list;
  std::unordered_set<uint64_t> _ps_server_sign_set;  // for unique filter
};
class
PaddlePSEnvironment
:
public
PSEnvironment
{
public:
explicit
PaddlePSEnvironment
()
{}
// NOLINT
virtual
~
PaddlePSEnvironment
()
{}
virtual
int32_t
SetPsServers
(
uint64_t
*
host_sign_list
,
int
node_num
)
{
_ps_server_list
.
clear
();
_ps_server_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
[
i
]
>
0
)
{
PSHost
host
;
host
.
ParseFromUint64
(
host_sign_list
[
i
]);
_ps_server_list
.
push_back
(
host
);
_ps_server_sign_set
.
insert
(
host
.
SerializeToUint64
());
}
}
std
::
sort
(
_ps_server_list
.
begin
(),
_ps_server_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
return
0
;
}
virtual
int32_t
SetPsServers
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
int
node_num
)
{
_ps_server_list
.
clear
();
_ps_server_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
->
at
(
i
)
!=
""
)
{
PSHost
host
;
host
.
ParseFromString
(
host_sign_list
->
at
(
i
));
_ps_server_list
.
push_back
(
host
);
_ps_server_sign_set
.
insert
(
host
.
rank
);
}
}
std
::
sort
(
_ps_server_list
.
begin
(),
_ps_server_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
return
0
;
}
virtual
int32_t
SetPsClients
(
uint64_t
*
host_sign_list
,
int
node_num
)
{
_ps_client_list
.
clear
();
_ps_client_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
[
i
]
>
0
)
{
PSHost
host
;
host
.
ParseFromUint64
(
host_sign_list
[
i
]);
_ps_client_list
.
push_back
(
host
);
_ps_client_sign_set
.
insert
(
host
.
SerializeToUint64
());
}
}
std
::
sort
(
_ps_client_list
.
begin
(),
_ps_client_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
return
0
;
}
virtual
int32_t
SetPsClients
(
const
std
::
vector
<
std
::
string
>
*
host_sign_list
,
int
node_num
)
{
_ps_client_list
.
clear
();
_ps_client_sign_set
.
clear
();
for
(
int
i
=
0
;
i
<
node_num
;
++
i
)
{
if
(
host_sign_list
->
at
(
i
)
!=
""
)
{
PSHost
host
;
host
.
ParseFromString
(
host_sign_list
->
at
(
i
));
_ps_client_list
.
push_back
(
host
);
_ps_client_sign_set
.
insert
(
host
.
rank
);
}
}
std
::
sort
(
_ps_client_list
.
begin
(),
_ps_client_list
.
end
(),
[](
const
PSHost
&
h1
,
const
PSHost
&
h2
)
{
return
h1
.
rank
<
h2
.
rank
;
});
VLOG
(
1
)
<<
"env.set_ps_clients done
\n
"
;
return
0
;
}
virtual
uint64_t
GetLocalHostSign
()
{
if
(
_ps_client_list
.
size
()
>
0
)
{
return
_ps_client_list
[
0
].
SerializeToUint64
();
}
else
{
return
0
;
}
}
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
// Dispatches a PS RPC either to the in-process graph service (when the
// target channel is the local one) or over brpc to a remote server.
// The local path runs on task_pool to keep the caller non-blocking, matching
// the asynchronous semantics of the remote path.
void GraphPsService_Stub::service(
    ::google::protobuf::RpcController *controller,
    const ::paddle::distributed::PsRequestMessage *request,
    ::paddle::distributed::PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  // Short-circuit: a non-null local service AND a channel that compares
  // equal to local_channel means this request targets ourselves.
  if (graph_service != NULL && local_channel == channel()) {
    // VLOG(0)<<"use local";
    // done is invoked by the local service just as brpc would invoke it.
    task_pool->enqueue([this, controller, request, response, done]() -> int {
      this->graph_service->service(controller, request, response, done);
      return 0;
    });
  } else {
    // VLOG(0)<<"use server";
    // Remote path: delegate to the generated protobuf stub.
    PsService_Stub::service(controller, request, response, done);
  }
}
int
GraphBrpcClient
::
get_server_index_by_id
(
int64_t
id
)
{
int
shard_num
=
get_shard_num
();
int
shard_per_server
=
shard_num
%
server_size
==
0
?
shard_num
/
server_size
:
shard_num
/
server_size
+
1
;
return
id
%
shard_num
/
shard_per_server
;
}
// Fetches the named features of the given nodes from the graph servers.
//
// Node ids are bucketed by owning server (one RPC per distinct server); the
// response for each bucket is a flat byte stream of [len(size_t)][bytes]
// records ordered feature-major, node-minor, which is scattered back into
// res[feat_idx][query_idx]. res must be pre-sized by the caller to
// feature_names.size() x node_ids.size().
//
// Returns a future resolving to 0 on (any) success, -1 when every request
// failed. NOTE(review): res and feature_names are captured by reference in
// the completion callback — the caller must keep them alive until the
// future resolves.
std::future<int32_t> GraphBrpcClient::get_node_feat(
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    std::vector<std::vector<std::string>> &res) {
  // First pass: discover the set of servers involved.
  // server2request[s] = position of server s in request2server, or -1.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Second pass: bucket node ids (and their original positions) per request.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // Completion callback: runs once all request_call_num responses arrived.
  // The closure manages its own lifetime (brpc done-callback convention).
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
              0) {
            ++fail_num;
          } else {
            // Copy the response attachment into a contiguous buffer, then
            // walk it: for each feature, for each node of this bucket, a
            // size_t length prefix followed by that many feature bytes.
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
            for (size_t feat_idx = 0; feat_idx < feature_names.size();
                 ++feat_idx) {
              for (size_t node_idx = 0;
                   node_idx < query_idx_buckets.at(request_idx).size();
                   ++node_idx) {
                int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
                size_t feat_len = *(size_t *)(buffer);
                buffer += sizeof(size_t);
                auto feature = std::string(buffer, feat_len);
                // Scatter back to the caller's (feature, original-node) slot.
                res[feat_idx][query_idx] = feature;
                buffer += feat_len;
              }
            }
          }
          // NOTE(review): this check sits inside the loop, so it only fires
          // on the iteration where fail_num reaches request_call_num (i.e.
          // every request failed). Preserved as-is.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  // Fill and fire one request per involved server. Wire format of params:
  // [idx_ (int)] [node ids (int64_t[])] [feature names joined by '\t'].
  for (size_t request_idx = 0; request_idx < request_call_num;
       ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());

    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }

  return fut;
}
// Broadcasts a PS_GRAPH_CLEAR command to every graph server, clearing the
// nodes of (type_id, idx_) in the given table.
// Returns a future resolving to 0 only if ALL servers succeeded, -1 if any
// failed (the response check breaks out on the first failure).
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
                                                  int type_id,
                                                  int idx_) {
  // One request per server; the closure self-manages its lifetime and
  // fires the lambda once all server_size responses have arrived.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      server_size, [&, server_size = this->server_size](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < server_size;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
            ++fail_num;
            break;  // one failure is enough to fail the whole call
          }
        }
        ret = fail_num == 0 ? 0 : -1;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  // Fan out: identical request to every server. Param wire format:
  // [type_id (int)] [idx_ (int)].
  for (size_t i = 0; i < server_size; i++) {
    int server_index = i;
    closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
    closure->request(server_index)->set_table_id(table_id);
    closure->request(server_index)->set_client_id(_client_id);

    closure->request(server_index)->add_params((char *)&type_id, sizeof(int));
    closure->request(server_index)->add_params((char *)&idx_, sizeof(int));

    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(server_index),
                     closure->request(server_index),
                     closure->response(server_index),
                     closure);
  }

  return fut;
}
// Adds the given nodes (optionally flagged weighted) to the graph table.
//
// Node ids are bucketed by owning server; one PS_GRAPH_ADD_GRAPH_NODE RPC is
// issued per involved server. Param wire format per request:
// [idx_ (int)] [node ids (int64_t[])] [optional is_weighted flags (bool[])].
// is_weighted_list may be empty (no weight flags sent) or parallel to
// node_id_list; missing trailing entries default to false.
// Returns a future resolving to 0 if at least one request succeeded,
// -1 when every request failed.
std::future<int32_t> GraphBrpcClient::add_graph_node(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> &node_id_list,
    std::vector<bool> &is_weighted_list) {
  std::vector<std::vector<int64_t>> request_bucket;
  std::vector<std::vector<bool>> is_weighted_bucket;
  bool add_weight = is_weighted_list.size() > 0;
  // index_mapping[s] = bucket index of server s, or -1 if not yet seen.
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
    if (add_weight)
      is_weighted_bucket[index_mapping[server_index]].push_back(
          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
                                              : false);
  }
  size_t request_call_num = request_bucket.size();
  // Completion callback: counts failed responses; the call as a whole fails
  // only when every request failed. Closure self-manages its lifetime.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_ADD_GRAPH_NODE) != 0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t request_idx = 0; request_idx < request_call_num;
       ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    if (add_weight) {
      // std::vector<bool> is bit-packed and has no .data(); flatten into a
      // contiguous bool buffer. Heap-allocated: the previous code used a
      // variable-length array, a non-standard extension with stack-overflow
      // risk for large buckets.
      const size_t weight_num = is_weighted_bucket[request_idx].size();
      std::unique_ptr<bool[]> weighted(new bool[weight_num + 1]);
      for (size_t j = 0; j < weight_num; j++)
        weighted[j] = is_weighted_bucket[request_idx][j];
      closure->request(request_idx)
          ->add_params((char *)weighted.get(), sizeof(bool) * weight_num);
    }
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }

  return fut;
}
// Removes the given nodes from the graph table.
//
// Node ids are bucketed by owning server; one PS_GRAPH_REMOVE_GRAPH_NODE RPC
// is issued per involved server. Param wire format per request:
// [idx_ (int)] [node ids (int64_t[])].
// Returns a future resolving to 0 if at least one request succeeded,
// -1 when every request failed.
std::future<int32_t> GraphBrpcClient::remove_graph_node(
    uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
  std::vector<std::vector<int64_t>> request_bucket;
  // index_mapping[s] = bucket index of server s, or -1 if not yet seen.
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
  }
  size_t request_call_num = request_bucket.size();
  // Completion callback: the call fails only when every request failed.
  // The closure self-manages its lifetime (brpc done-callback convention).
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();

  for (size_t request_idx = 0; request_idx < request_call_num;
       ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();

    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// char* &buffer,int &actual_size
std
::
future
<
int32_t
>
GraphBrpcClient
::
batch_sample_neighbors
(
uint32_t
table_id
,
int
idx_
,
std
::
vector
<
int64_t
>
node_ids
,
int
sample_size
,
// std::vector<std::vector<std::pair<int64_t, float>>> &res,
std
::
vector
<
std
::
vector
<
int64_t
>>
&
res
,
std
::
vector
<
std
::
vector
<
float
>>
&
res_weight
,
bool
need_weight
,
int
server_index
)
{
if
(
server_index
!=
-
1
)
{
res
.
resize
(
node_ids
.
size
());
if
(
need_weight
)
{
res_weight
.
resize
(
node_ids
.
size
());
}
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
1
,
[
&
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
if
(
closure
->
check_response
(
0
,
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
)
!=
0
)
{
ret
=
-
1
;
}
else
{
auto
&
res_io_buffer
=
closure
->
cntl
(
0
)
->
response_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
res_io_buffer
);
size_t
bytes_size
=
io_buffer_itr
.
bytes_left
();
std
::
unique_ptr
<
char
[]
>
buffer_wrapper
(
new
char
[
bytes_size
]);
char
*
buffer
=
buffer_wrapper
.
get
();
io_buffer_itr
.
copy_and_forward
((
void
*
)(
buffer
),
bytes_size
);
size_t
node_num
=
*
(
size_t
*
)
buffer
;
int
*
actual_sizes
=
(
int
*
)(
buffer
+
sizeof
(
size_t
));
char
*
node_buffer
=
buffer
+
sizeof
(
size_t
)
+
sizeof
(
int
)
*
node_num
;
int
offset
=
0
;
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
int
actual_size
=
actual_sizes
[
node_idx
];
int
start
=
0
;
while
(
start
<
actual_size
)
{
res
[
node_idx
].
emplace_back
(
*
(
int64_t
*
)(
node_buffer
+
offset
+
start
));
start
+=
GraphNode
::
id_size
;
if
(
need_weight
)
{
res_weight
[
node_idx
].
emplace_back
(
*
(
float
*
)(
node_buffer
+
offset
+
start
));
start
+=
GraphNode
::
weight_size
;
}
}
offset
+=
actual_size
;
}
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
;
closure
->
request
(
0
)
->
set_cmd_id
(
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
);
closure
->
request
(
0
)
->
set_table_id
(
table_id
);
closure
->
request
(
0
)
->
set_client_id
(
_client_id
);
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
idx_
,
sizeof
(
int
));
closure
->
request
(
0
)
->
add_params
((
char
*
)
node_ids
.
data
(),
sizeof
(
int64_t
)
*
node_ids
.
size
());
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
need_weight
,
sizeof
(
bool
));
;
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub
rpc_stub
=
getServiceStub
(
GetCmdChannel
(
server_index
));
closure
->
cntl
(
0
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
0
),
closure
->
request
(
0
),
closure
->
response
(
0
),
closure
);
return
fut
;
}
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
res
.
clear
();
res_weight
.
clear
();
for
(
size_t
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
if
(
server2request
[
server_index
]
==
-
1
)
{
server2request
[
server_index
]
=
request2server
.
size
();
request2server
.
push_back
(
server_index
);
}
// res.push_back(std::vector<std::pair<int64_t, float>>());
res
.
push_back
({});
if
(
need_weight
)
{
res_weight
.
push_back
({});
}
}
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
vector
<
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
size_t
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
int
request_idx
=
server2request
[
server_index
];
node_id_buckets
[
request_idx
].
push_back
(
node_ids
[
query_idx
]);
query_idx_buckets
[
request_idx
].
push_back
(
query_idx
);
}
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
&
,
node_id_buckets
,
query_idx_buckets
,
request_call_num
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
size_t
fail_num
=
0
;
for
(
size_t
request_idx
=
0
;
request_idx
<
request_call_num
;
++
request_idx
)
{
if
(
closure
->
check_response
(
request_idx
,
PS_GRAPH_SAMPLE_NEIGHBORS
)
!=
0
)
{
++
fail_num
;
}
else
{
auto
&
res_io_buffer
=
closure
->
cntl
(
request_idx
)
->
response_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
res_io_buffer
);
size_t
bytes_size
=
io_buffer_itr
.
bytes_left
();
std
::
unique_ptr
<
char
[]
>
buffer_wrapper
(
new
char
[
bytes_size
]);
char
*
buffer
=
buffer_wrapper
.
get
();
io_buffer_itr
.
copy_and_forward
((
void
*
)(
buffer
),
bytes_size
);
size_t
node_num
=
*
(
size_t
*
)
buffer
;
int
*
actual_sizes
=
(
int
*
)(
buffer
+
sizeof
(
size_t
));
char
*
node_buffer
=
buffer
+
sizeof
(
size_t
)
+
sizeof
(
int
)
*
node_num
;
int
offset
=
0
;
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
int
query_idx
=
query_idx_buckets
.
at
(
request_idx
).
at
(
node_idx
);
int
actual_size
=
actual_sizes
[
node_idx
];
int
start
=
0
;
while
(
start
<
actual_size
)
{
res
[
query_idx
].
emplace_back
(
*
(
int64_t
*
)(
node_buffer
+
offset
+
start
));
start
+=
GraphNode
::
id_size
;
if
(
need_weight
)
{
res_weight
[
query_idx
].
emplace_back
(
*
(
float
*
)(
node_buffer
+
offset
+
start
));
start
+=
GraphNode
::
weight_size
;
}
}
offset
+=
actual_size
;
}
}
if
(
fail_num
==
request_call_num
)
{
ret
=
-
1
;
}
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
for
(
size_t
request_idx
=
0
;
request_idx
<
request_call_num
;
++
request_idx
)
{
int
server_index
=
request2server
[
request_idx
];
closure
->
request
(
request_idx
)
->
set_cmd_id
(
PS_GRAPH_SAMPLE_NEIGHBORS
);
closure
->
request
(
request_idx
)
->
set_table_id
(
table_id
);
closure
->
request
(
request_idx
)
->
set_client_id
(
_client_id
);
size_t
node_num
=
node_id_buckets
[
request_idx
].
size
();
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
idx_
,
sizeof
(
int
));
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
int64_t
)
*
node_num
);
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
need_weight
,
sizeof
(
bool
));
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub
rpc_stub
=
getServiceStub
(
GetCmdChannel
(
server_index
));
closure
->
cntl
(
request_idx
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
request_idx
),
closure
->
request
(
request_idx
),
closure
->
response
(
request_idx
),
closure
);
}
return
fut
;
}
// Asks a single server (`server_index`) for `sample_size` random node ids
// drawn from table `table_id` (node/edge set `type_id`, graph index `idx_`).
// The sampled ids are appended to `ids` by the response callback.
// Returns a future that resolves to 0 on success, -1 on RPC failure.
// NOTE: `ids` is captured by reference and must outlive the future.
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
    uint32_t table_id, int type_id, int idx_, int server_index,
    int sample_size, std::vector<int64_t> &ids) {
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // unique_ptr so the scratch buffer is released on every path
      // (the original raw new[]/delete[] leaked on exceptions).
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char *buffer = buffer_wrapper.get();
      // BUG FIX: the attachment bytes were never copied into `buffer`
      // before parsing, so the loop below read uninitialized heap memory.
      // Mirror the copy_and_forward done by pull_graph_list /
      // batch_sample_neighbors.
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      // Payload is a packed array of node ids, GraphNode::id_size bytes each.
      size_t index = 0;
      while (index < bytes_size) {
        ids.push_back(*(int64_t *)(buffer + index));
        index += GraphNode::id_size;
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Request layout (must match GraphBrpcService::graph_random_sample_nodes):
  // params(0)=type_id, params(1)=idx_, params(2)=sample_size — all int-sized.
  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&sample_size, sizeof(int));
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
                   closure);
  return fut;
}
// Pulls a slice of the node list from one server: nodes [start, start+size)
// of graph `idx_` in set `type_id`, taking every `step`-th entry. Recovered
// FeatureNode objects are appended to `res` by the response callback.
// Returns a future resolving to 0 on success, -1 on RPC failure.
// NOTE: `res` is captured by reference and must outlive the future.
std::future<int32_t> GraphBrpcClient::pull_graph_list(
    uint32_t table_id, int type_id, int idx_, int server_index, int start,
    int size, int step, std::vector<FeatureNode> &res) {
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // unique_ptr instead of raw new[]/delete[]: the original leaked the
      // buffer if recover_from_buffer / push_back threw.
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char *buffer = buffer_wrapper.get();
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      // Payload is a sequence of serialized FeatureNodes; each node reports
      // its own serialized size via get_size(false).
      size_t index = 0;
      while (index < bytes_size) {
        FeatureNode node;
        node.recover_from_buffer(buffer + index);
        index += node.get_size(false);
        res.push_back(node);
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Request layout (must match GraphBrpcService::pull_graph_list):
  // params(0)=type_id, params(1)=idx_, params(2)=start, params(3)=size,
  // params(4)=step — all int-sized.
  closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&start, sizeof(int));
  closure->request(0)->add_params((char *)&size, sizeof(int));
  closure->request(0)->add_params((char *)&step, sizeof(int));
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
                   closure);
  return fut;
}
// Writes feature values for a batch of nodes. Each node id is routed to the
// server that owns it (get_server_index_by_id); one RPC is issued per
// distinct server. `features[feat_idx][query_idx]` holds the value of
// feature `feature_names[feat_idx]` for node `node_ids[query_idx]`.
// Returns a future resolving to 0 unless every per-server request fails.
std::future<int32_t> GraphBrpcClient::set_node_feat(
    const uint32_t &table_id, int idx_, const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &features) {
  // First pass: discover which servers are involved and assign each a
  // dense request slot. server2request maps server index -> request slot
  // (-1 when that server is not needed).
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Second pass: bucket node ids, their original query positions, and the
  // per-feature value lists by request slot.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
      request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    if (features_idx_buckets[request_idx].size() == 0) {
      features_idx_buckets[request_idx].resize(feature_names.size());
    }
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      features_idx_buckets[request_idx][feat_idx].push_back(
          features[feat_idx][query_idx]);
    }
  }
  // The callback only counts failures; set_node_feat returns no payload.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
              0) {
            ++fail_num;
          }
          // Checked inside the loop, but only trips once all responses have
          // failed — equivalent to checking after the loop.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    // Request layout (must match GraphBrpcService::graph_set_node_feat):
    // params(0)=idx_, params(1)=node id array, params(2)=tab-joined feature
    // names, params(3)=packed feature values.
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    // set features
    // Packed as, for each feature then each node: a size_t length followed
    // by that many raw value bytes (feature-major order).
    std::string set_feature = "";
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
        size_t feat_len =
            features_idx_buckets[request_idx][feat_idx][node_idx].size();
        set_feature.append((char *)&feat_len, sizeof(size_t));
        set_feature.append(
            features_idx_buckets[request_idx][feat_idx][node_idx].data(),
            feat_len);
      }
    }
    closure->request(request_idx)
        ->add_params(set_feature.c_str(), set_feature.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
                     closure->response(request_idx), closure);
  }
  return fut;
}
// One-time client setup: runs the base BrpcPsClient initialization, caches
// the server count, and clears the local-service shortcut pointers (they are
// set later via set_local_channel / set_local_graph_service).
// Returns the base-class result so initialization failures are not masked.
int32_t GraphBrpcClient::Initialize() {
  // set_shard_num(_config.shard_num());
  // BUG FIX: the base-class Initialize() result was silently discarded;
  // propagate it so callers can detect a failed client setup.
  int32_t ret = BrpcPsClient::Initialize();
  server_size = GetServerNums();
  graph_service = nullptr;
  local_channel = nullptr;
  return ret;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_client.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace
paddle
{
namespace
distributed
{
// RPC stub that extends the generated PsService_Stub with a fast path for
// requests addressed to the local server: when `local_channel` matches the
// target and `graph_service` is set, service() can dispatch in-process on
// `task_pool` instead of going through brpc (implementation in the .cc).
class GraphPsService_Stub : public PsService_Stub {
 public:
  // `channel`      — remote channel this stub speaks to (required).
  // `local_channel`— channel of the co-located server, used to detect
  //                  "talking to myself"; may be NULL.
  // `service`      — in-process service for the local shortcut; may be NULL.
  // `thread_num`   — size of the pool used for local dispatch.
  GraphPsService_Stub(::google::protobuf::RpcChannel *channel,
                      ::google::protobuf::RpcChannel *local_channel = NULL,
                      GraphBrpcService *service = NULL, int thread_num = 1)
      : PsService_Stub(channel) {
    this->local_channel = local_channel;
    this->graph_service = service;
    task_pool.reset(new ::ThreadPool(thread_num));
  }
  virtual ~GraphPsService_Stub() {}

  // implements PsService ------------------------------------------
  GraphBrpcService *graph_service;        // non-owning; NULL when no local service
  std::shared_ptr<::ThreadPool> task_pool;  // executes local-dispatch tasks
  ::google::protobuf::RpcChannel *local_channel;  // non-owning; NULL if unset
  GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub);
  // Overrides the generated service() entry point (see .cc for the
  // local-vs-remote routing).
  void service(::google::protobuf::RpcController *controller,
               const ::paddle::distributed::PsRequestMessage *request,
               ::paddle::distributed::PsResponseMessage *response,
               ::google::protobuf::Closure *done);
};
// Client side of the distributed graph service. Extends BrpcPsClient with
// graph-specific RPCs (neighbor sampling, node sampling, node-list pulls,
// feature get/set, graph mutation). Node ids are routed to owning servers
// via get_server_index_by_id.
class GraphBrpcClient : public BrpcPsClient {
 public:
  GraphBrpcClient() {}
  virtual ~GraphBrpcClient() {}
  // given a batch of nodes, sample graph_neighbors for each of them
  // res[i] receives the sampled neighbor ids of node_ids[i]; when
  // need_weight is true, res_weight[i] receives the matching edge weights.
  // server_index == -1 routes each id to its owning server; otherwise the
  // whole batch is sent to that single server.
  virtual std::future<int32_t> batch_sample_neighbors(
      uint32_t table_id, int idx, std::vector<int64_t> node_ids,
      int sample_size, std::vector<std::vector<int64_t>>& res,
      std::vector<std::vector<float>>& res_weight, bool need_weight,
      int server_index = -1);
  // Pulls nodes [start, start+size) (every `step`-th) of graph `idx` from
  // one server into `res`.
  virtual std::future<int32_t> pull_graph_list(uint32_t table_id, int type_id,
                                               int idx, int server_index,
                                               int start, int size, int step,
                                               std::vector<FeatureNode>& res);
  // Samples `sample_size` random node ids from one server into `ids`.
  virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
                                                   int type_id, int idx,
                                                   int server_index,
                                                   int sample_size,
                                                   std::vector<int64_t>& ids);
  // Fetches feature values: res[feat_idx][i] is feature_names[feat_idx] of
  // node_ids[i].
  virtual std::future<int32_t> get_node_feat(
      const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      std::vector<std::vector<std::string>>& res);
  // Writes feature values; `features` is indexed like get_node_feat's `res`.
  virtual std::future<int32_t> set_node_feat(
      const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      const std::vector<std::vector<std::string>>& features);
  // Clears all nodes of graph `idx` in set `type_id` on every server.
  virtual std::future<int32_t> clear_nodes(uint32_t table_id, int type_id,
                                           int idx);
  // Adds nodes; is_weighted_list[i] flags whether node_id_list[i]'s edges
  // carry weights (may be empty).
  virtual std::future<int32_t> add_graph_node(
      uint32_t table_id, int idx, std::vector<int64_t>& node_id_list,
      std::vector<bool>& is_weighted_list);
  virtual std::future<int32_t> remove_graph_node(
      uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
  virtual int32_t Initialize();
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  // Maps a node id to the index of the server that owns it.
  int get_server_index_by_id(int64_t id);
  // Records the channel of the co-located server so getServiceStub can
  // short-circuit RPCs to it.
  void set_local_channel(int index) {
    this->local_channel = GetCmdChannel(index);
  }
  void set_local_graph_service(GraphBrpcService* graph_service) {
    this->graph_service = graph_service;
  }
  // Builds a stub that knows about the local channel/service shortcut.
  GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
                                     int thread_num = 1) {
    return GraphPsService_Stub(channel, local_channel, graph_service,
                               thread_num);
  }

 private:
  int shard_num;             // number of shards; see get/set_shard_num
  size_t server_size;        // number of PS servers, cached in Initialize()
  ::google::protobuf::RpcChannel* local_channel;  // non-owning
  GraphBrpcService* graph_service;                // non-owning
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
distributed
{
// Guard used at the top of table-scoped RPC handlers: if the requested
// table_id resolved to no table, record an error on the response and
// return -1 from the enclosing handler.
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
// Creates the configured PsBaseService subclass, initializes it, and
// registers it with the embedded brpc server. Returns 0 on success,
// -1 on any configuration/registration failure.
int32_t GraphBrpcServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  // Reflection-style factory: instantiates the service class named in the
  // config (must have been registered with CREATE_PSCORE_CLASS's registry).
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service owns the instance from here on.
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  // SERVER_DOESNT_OWN_SERVICE: brpc must not delete the service —
  // ownership stays with _service.
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Returns the raw (non-owning) channel to peer server `server_index`.
// Channels are created by build_peer2peer_connection; the caller must not
// delete the returned pointer.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  auto &peer_channel = _pserver_channels[server_index];
  return peer_channel.get();
}
// Starts the brpc server on ip:port and registers this server with the
// environment. Thread count is max(hardware concurrency, trainer count).
// NOTE(review): returns 0 on both the success and the failure path, so
// callers cannot detect a failed Start from the return value — confirm
// whether a nonzero error code was intended on failure.
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  // Ensure at least one worker thread per trainer.
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
  // Publish this server's endpoint so peers/clients can discover it.
  _environment->RegistePsServer(ip, port, _rank);
  return 0;
}
// Opens a pooled brpc channel to every PS server listed in the environment
// (including self), enabling server-to-server requests such as
// sample_neighbors_across_multi_servers. Falls back to a numeric-IP
// endpoint if the hostname form fails. Returns 0 on success, -1 if any
// peer is unreachable.
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
  auto _env = Environment();
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;
  std::vector<PSHost> server_list = _env->GetPsServers();
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) !=
        0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      // Retry with the integer-IP form of the endpoint (handles hosts whose
      // name form is not resolvable from this process).
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}
// RPC handler: removes all nodes of graph `idx_` in node/edge set `type_id`
// from the table. params(0)=type_id, params(1)=idx_ (both int-sized).
// Returns 0; malformed requests are reported via the response code.
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // FIX: every sibling handler validates the table and the parameter count
  // before dereferencing request.params(); this one dereferenced them
  // unconditionally (out-of-range access on a short request).
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(response, -1,
                      "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}
// RPC handler: inserts a batch of nodes into graph `idx_`.
// params(0)=idx_ (int), params(1)=packed int64_t node ids,
// optional params(2)=packed bool per-node "edges are weighted" flags.
// Returns 0; malformed requests are reported via the response code.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(response, -1,
                      "add_graph_node request requires at least 2 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);
  // Empty when params(2) is absent; GraphTable treats that as unweighted.
  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    size_t weight_list_size = request.params(2).size() / sizeof(bool);
    bool *is_weighted_buffer = (bool *)(request.params(2).c_str());
    is_weighted_list = std::vector<bool>(is_weighted_buffer,
                                         is_weighted_buffer +
                                         weight_list_size);
  }
  // if (request.params_size() == 2) {
  //   size_t weight_list_size = request.params(1).size() / sizeof(bool);
  //   bool *is_weighted_buffer = (bool *)(request.params(1).c_str());
  //   is_weighted_list = std::vector<bool>(is_weighted_buffer,
  //                                        is_weighted_buffer +
  //                                        weight_list_size);
  // }
  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// RPC handler: deletes a batch of nodes from graph `idx`.
// params(0)=idx (int), params(1)=packed int64_t node ids.
// Returns 0; malformed requests are reported via the response code.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }
  int graph_idx = *(int *)(request.params(0).c_str());
  const auto &id_bytes = request.params(1);
  size_t id_count = id_bytes.size() / sizeof(int64_t);
  int64_t *id_begin = (int64_t *)(id_bytes.c_str());
  std::vector<int64_t> node_ids(id_begin, id_begin + id_count);
  ((GraphTable *)table)->remove_graph_node(graph_idx, node_ids);
  return 0;
}
// Returns the TCP port the embedded brpc server is listening on.
int32_t GraphBrpcServer::Port() {
  const auto listen_addr = _server.listen_address();
  return listen_addr.port;
}
// Registers every supported PsCmdID with its member-function handler and
// runs shard initialization. Returns 0.
int32_t GraphBrpcService::Initialize() {
  _is_initialize_shard_info = false;
  // Generic PS admin commands.
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
  _service_handler_map[PS_PRINT_TABLE_STAT] =
      &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
  // Graph-specific commands.
  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
      &GraphBrpcService::sample_neighbors_across_multi_servers;
  // _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
  //     &GraphBrpcService::use_neighbors_sample_cache;
  // _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
  //     &GraphBrpcService::load_graph_split_config;
  // Shard initialization: the server_list shard info only becomes available
  // from the environment after the server has started.
  InitializeShardInfo();
  return 0;
}
// Lazily propagates (rank, server count) shard information to every table.
// Uses a check / lock / re-check sequence so concurrent callers initialize
// at most once. Returns 0.
int32_t GraphBrpcService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    // Re-check under the lock: another thread may have finished while we
    // were waiting.
    if (_is_initialize_shard_info) {
      return 0;
    }
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // FIX: iterate by reference — `for (auto itr : ...)` copied the whole
    // map entry (key + shared_ptr, an atomic refcount bump) per iteration.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, server_size);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// Main brpc entry point: validates the request, looks up the handler
// registered for its cmd_id, dispatches, and translates a nonzero handler
// result into an error code/message on the response. The ClosureGuard
// guarantees `done` runs on every exit path.
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  // FIX: removed the unused local `log_label` ("ReceiveCmd-").
  if (!request->has_table_id()) {
    // FIX: error message typo "tabel_id" -> "table_id".
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }
  response->set_err_code(0);
  response->set_err_msg("");
  // May be NULL for an unknown table_id; each handler validates via
  // CHECK_TABLE_EXIST.
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    if (!response->has_err_msg()) {
      response->set_err_msg("server internal error");
    }
  }
}
// RPC handler: forwards a barrier request (params(0)=barrier type string)
// from trainer `client_id` to the table's Barrier implementation.
// Returns 0; malformed requests are reported via the response code.
int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    // FIX: error message typo "requeired" -> "required".
    set_response_code(response, -1,
                      "PsRequestMessage.params is required at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
  table->Barrier(trainer_id, barrier_type);
  return 0;
}
// RPC handler: serializes the table's (first, second) stat pair into a
// binary archive and returns it in the response data field.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const std::pair<int64_t, int64_t> stat = table->PrintTableStat();
  paddle::framework::BinaryArchive archive;
  archive << stat.first;
  archive << stat.second;
  std::string table_info(archive.Buffer(), archive.Length());
  response.set_data(table_info);
  return 0;
}
// RPC handler: loads one table from disk. params(0)=path,
// params(1)=load_param string. Returns 0 on success, -1 on failure
// (with the reason recorded on the response).
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    // FIX: error message typo "requeired" -> "required".
    set_response_code(
        response, -1,
        "PsRequestMessage.datas is required at least 2 for path & load_param");
    return -1;
  }
  if (table->Load(request.params(0), request.params(1)) != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
// RPC handler: loads every registered table by delegating to LoadOneTable
// with the same (path, load_param) request. Stops at the first failure and
// returns -1; otherwise 0. The `table` argument is unused — all tables are
// taken from the server's table map.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  for (auto &itr : table_map) {
    if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) {
      LOG(ERROR) << "load table[" << itr.first << "] failed";
      return -1;
    }
  }
  return 0;
}
// RPC handler: shuts the server down asynchronously. Stop() runs on a
// detached thread so this handler can return (and the RPC complete) before
// the server tears itself down; the exit condition variable is notified to
// wake anything blocked on server shutdown.
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
  std::thread t_stop([p_server]() {
    p_server->Stop();
    // FIX: log message typo "Stoped" -> "Stopped".
    LOG(INFO) << "Server Stopped";
  });
  p_server->export_cv()->notify_all();
  t_stop.detach();
  return 0;
}
// RPC handler: stops the platform profiler and writes the collected events
// to a per-rank file ("server_<rank>_profile"). Table/request/cntl are
// unused. Returns 0.
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
// RPC handler: enables CPU profiling on this server. Table/request/cntl are
// unused. Returns 0.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
// RPC handler: returns a serialized slice of the node list in the response
// attachment. params(0)=type_id, params(1)=idx, params(2)=start,
// params(3)=size, params(4)=step — all int-sized. Returns 0; malformed
// requests are reported via the response code.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(response, -1,
                      "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx = *(int *)(request.params(1).c_str());
  int start = *(int *)(request.params(2).c_str());
  int size = *(int *)(request.params(3).c_str());
  int step = *(int *)(request.params(4).c_str());
  // int start = *(int *)(request.params(0).c_str());
  // int size = *(int *)(request.params(1).c_str());
  // int step = *(int *)(request.params(2).c_str());
  // GraphTable allocates the serialized payload into `buffer` and reports
  // its length via `actual_size`.
  std::unique_ptr<char[]> buffer;
  int actual_size;
  ((GraphTable *)table)
      ->pull_graph_list(type_id, idx, start, size, buffer, actual_size, false,
                        step);
  cntl->response_attachment().append(buffer.get(), actual_size);
  return 0;
}
// RPC handler: samples up to `sample_size` neighbors for each requested
// node. params(0)=idx_ (int), params(1)=packed int64_t node ids,
// params(2)=sample_size (int), params(3)=need_weight (bool).
// Response attachment layout: size_t node_num, int actual_sizes[node_num],
// then the per-node sample payloads back to back.
int32_t GraphBrpcService::graph_random_sample_neighbors(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // FIX: the message said "at least 3 arguments" while the check above
    // requires 4.
    set_response_code(
        response, -1,
        "graph_random_sample_neighbors request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  // FIX: the client packs sample_size with sizeof(int)
  // (batch_sample_neighbors), but this read dereferenced it as int64_t —
  // an 8-byte read from a 4-byte param (out-of-bounds / garbage on
  // big-endian). Read it with the type that was written.
  int sample_size = *(int *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(bool *)(request.params(2).c_str());
  std::vector<std::shared_ptr<char>> buffers(node_num);
  std::vector<int> actual_sizes(node_num, 0);
  ((GraphTable *)table)
      ->random_sample_neighbors(idx_, node_data, sample_size, buffers,
                                actual_sizes, need_weight);
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  // Returns a random sample of up to `size` nodes as an opaque buffer; an
  // empty attachment signals failure.
  // Fix: every sibling handler validates the table and the parameter count
  // before dereferencing them — this one read params(0..2) unchecked.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_random_sample_nodes request requires at least 3 arguments");
    return 0;
  }
  // params(0): type id; params(1): table index; params(2): sample count
  // (sent as int64).
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  size_t size = *(int64_t *)(request.params(2).c_str());

  std::unique_ptr<char[]> buffer;
  int actual_size;
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
      0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  } else {
    // Sampling failed: send an explicitly empty payload.
    cntl->response_attachment().append(NULL, 0);
  }
  return 0;
}
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  // Fetches the requested features for a batch of node ids and streams them
  // back as repeated (length:size_t, bytes) pairs, feature-major then
  // node-major.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response, -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }
  // params(0): feature table index; params(1): packed int64 node ids;
  // params(2): tab-separated feature names.
  const int feat_table_idx = *(int *)(request.params(0).c_str());
  const size_t id_count = request.params(1).size() / sizeof(int64_t);
  const int64_t *id_ptr = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(id_ptr, id_ptr + id_count);
  auto feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(id_count));
  ((GraphTable *)table)
      ->get_node_feat(feat_table_idx, node_ids, feature_names, feature);

  auto &out = cntl->response_attachment();
  for (size_t f = 0; f < feature_names.size(); ++f) {
    for (size_t n = 0; n < id_count; ++n) {
      const std::string &value = feature[f][n];
      size_t value_len = value.size();
      out.append(&value_len, sizeof(size_t));
      out.append(value.data(), value_len);
    }
  }
  return 0;
}
// Fans a neighbor-sampling request out across all servers that own part of
// the requested node set, samples the locally-owned bucket on this server,
// then merges every partial result back into one response in the caller's
// original node order.  Response layout matches graph_random_sample_neighbors:
// [node_num][per-node sizes][per-node payloads].
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(response, -1,
                      "sample_neighbors_across_multi_servers request requires "
                      "at least 4 arguments");
    return 0;
  }
  // params(0): table index; params(1): packed int64 node ids;
  // params(2): sample size (int64 on the wire); params(3): need_weight
  // (note: read as int64 here, though it is re-sent below as a bool).
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(int64_t *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t),
  //        size_of_size_t = sizeof(size_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(int64_t *)(request.params(2).c_str());
  // std::vector<int64_t> res = ((GraphTable
  // *)table).filter_out_non_exist_nodes(node_data, sample_size);

  // request2server: dense list of servers we must contact;
  // server2request: inverse map (server index -> slot in request2server).
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  std::vector<int64_t> local_id;
  std::vector<int> local_query_idx;
  size_t rank = GetRank();
  // First pass: discover which servers own at least one requested node.
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  // If this server owns some nodes, swap its slot to the back so that the
  // first remote_call_num slots are all remote and the local bucket is last.
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
  std::vector<std::shared_ptr<char>> local_buffers;
  std::vector<int> local_actual_sizes;
  // seq[i] records which request bucket query i fell into, so results can be
  // re-interleaved in the original order later.
  std::vector<size_t> seq;
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  // Second pass: bucket each node id by its owning server.
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
    // The last bucket is served locally, not via RPC.
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  // local_fut gates the merge callback until the local sampling (done after
  // the RPCs are issued, below) has finished filling local_buffers.
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  // Merge callback: runs once all remote responses (if any) have arrived.
  std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num;
         ++request_idx) {
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer =
            closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        // Skip the remote response's own node-count header.
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    // Pass 1 over queries: collect each node's payload size, pulling from
    // either the matching remote iterator or the local result array.
    int size;
    int local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;  // failed server => empty payload for its nodes
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));
    // Pass 2: append payload bytes in the caller's original node order.
    local_index = 0;
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };

  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Issue one PS_GRAPH_SAMPLE_NEIGHBORS RPC per remote bucket.
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();  // shadows outer
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
    PsService_Stub rpc_stub(
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
    // GraphPsService_Stub rpc_stub =
    //     getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx), closure);
  }
  // Sample the locally-owned bucket (always the last one after the swap).
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
        ->random_sample_neighbors(idx_, node_id_buckets.back().data(),
                                  sample_size, local_buffers,
                                  local_actual_sizes, need_weight);
  }
  // Unblock the merge callback now that local results are ready.
  local_promise.get()->set_value(0);
  // With no remote calls the closure never fires on its own; invoke merge
  // directly.
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  // Writes feature values for a batch of node ids.  params(3) carries the
  // values packed as repeated (length:size_t, bytes) pairs, feature-major,
  // matching the layout produced by graph_get_node_feat.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fix: the guard requires 4 params (idx, node ids, feature names,
    // packed values) but the old message claimed "at least 3".
    set_response_code(
        response, -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  // params(0): feature table index.
  int idx_ = *(int *)(request.params(0).c_str());
  // params(1): packed int64 node ids.
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);
  // params(2): tab-separated feature names.
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");

  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));
  // params(3): packed feature values; walk the (length, bytes) pairs in the
  // same feature-major order used when serializing.
  const char *buffer = request.params(3).c_str();
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      size_t feat_len = *(size_t *)(buffer);
      buffer += sizeof(size_t);
      auto feat = std::string(buffer, feat_len);
      features[feat_idx][node_idx] = feat;
      buffer += feat_len;
    }
  }
  ((GraphTable *)table)
      ->set_node_feat(idx_, node_ids, feature_names, features);
  return 0;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/graph_brpc_server.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/server.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace
paddle
{
namespace
distributed
{
// brpc-based server for the graph parameter-server role.  Owns the
// underlying brpc::Server, the registered service object, and the channels
// used for server-to-server commands.
class GraphBrpcServer : public PSServer {
 public:
  GraphBrpcServer() {}
  virtual ~GraphBrpcServer() {}
  // Non-owning accessor for the service registered on this server.
  PsBaseService *get_service() { return _service.get(); }
  // Starts the server listening on ip:port (defined in the .cc file).
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Establishes direct channels to the other servers for peer commands.
  virtual int32_t build_peer2peer_connection(int rank);
  // Channel to a peer server, used to forward graph commands.
  virtual brpc::Channel *GetCmdChannel(size_t server_index);
  // Stops and joins the brpc server; idempotent (guarded by stoped_).
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_) return 0;
    stoped_ = true;
    // cv_.notify_all();
    _server.Stop(1000);  // 1000 ms grace period for in-flight requests
    _server.Join();
    return 0;
  }
  int32_t Port();
  // Exposes the internal condition variable (used by external waiters).
  std::condition_variable *export_cv() { return &cv_; }

 private:
  virtual int32_t Initialize();
  mutable std::mutex mutex_;      // guards stoped_ in Stop()
  std::condition_variable cv_;
  bool stoped_ = false;           // sic: set once Stop() has run
  int rank;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer server for command traffic.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class GraphBrpcService;

// Member-function-pointer type for GraphBrpcService request handlers.
// Every graph RPC handler (listing, sampling, feature get/set, ...) shares
// this signature so handlers can be stored in one dispatch table.
typedef int32_t (GraphBrpcService::*serviceFunc)(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl);
// brpc service implementation for the graph parameter-server.  All handlers
// below share the serviceFunc signature; _service_handler_map presumably maps
// request cmd ids to them for dispatch from service() — confirm in the .cc.
class GraphBrpcService : public PsBaseService {
 public:
  virtual int32_t Initialize() override;
  // Entry point for every incoming RPC; routes to the handler methods below.
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;

 protected:
  // cmd id -> handler member function.
  std::unordered_map<int32_t, serviceFunc> _service_handler_map;
  int32_t InitializeShardInfo();
  // --- graph-specific handlers ---
  int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
                          PsResponseMessage &response, brpc::Controller *cntl);
  int32_t graph_random_sample_neighbors(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl);
  int32_t graph_random_sample_nodes(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl);
  int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  int32_t clear_nodes(Table *table, const PsRequestMessage &request,
                      PsResponseMessage &response, brpc::Controller *cntl);
  int32_t add_graph_node(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);
  // --- generic PS admin handlers ---
  int32_t Barrier(Table *table, const PsRequestMessage &request,
                  PsResponseMessage &response, brpc::Controller *cntl);
  int32_t LoadOneTable(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t StopServer(Table *table, const PsRequestMessage &request,
                     PsResponseMessage &response, brpc::Controller *cntl);
  int32_t StartProfiler(Table *table, const PsRequestMessage &request,
                        PsResponseMessage &response, brpc::Controller *cntl);
  int32_t StopProfiler(Table *table, const PsRequestMessage &request,
                       PsResponseMessage &response, brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table, const PsRequestMessage &request,
                         PsResponseMessage &response, brpc::Controller *cntl);
  // Fans neighbor sampling out to peer servers and merges the results.
  int32_t sample_neighbors_across_multi_servers(Table *table,
                                                const PsRequestMessage &request,
                                                PsResponseMessage &response,
                                                brpc::Controller *cntl);
  int32_t use_neighbors_sample_cache(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl);
  int32_t load_graph_split_config(Table *table, const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl);

 private:
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;  // guards InitializeShardInfo()
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
  const int sample_nodes_ranges = 23;
  size_t server_size;  // number of servers in the cluster
  std::shared_ptr<::ThreadPool> task_pool;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/heter_client.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
// Upper bound on the number of participants in one heter group.
DEFINE_int32(heter_world_size, 100, "group size");  // group max size
// Timeout (seconds) for send/recv traffic routed through the switch service.
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
namespace
paddle
{
namespace
distributed
{
// Process-wide HeterClient singleton (presumably lazily constructed under
// mtx_ by a factory in the header — confirm there).
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
// Guards singleton construction.
std::mutex HeterClient::mtx_;
// Separate singleton for the switch-service flavor of the client.
std::shared_ptr<HeterClient> HeterClient::switch_s_instance_ = nullptr;
// Reads the scalar "microbatch_id" variable out of `scope` and returns it as
// an int.  The tensor is read as float (assumes the micro id is stored as a
// float tensor — TODO confirm with the producer side); GPU tensors are first
// copied to host through the context's CUDA stream.  Returns -1 if the value
// cannot be read (e.g. non-CPU place without CUDA support compiled in).
int GetMicroId(const platform::DeviceContext& ctx,
               const framework::Scope* scope) {
  framework::Variable* var = scope->FindVar("microbatch_id");
  PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
                    platform::errors::InvalidArgument(
                        "the type of micro id shoulde be LoDTensor."));
  auto micro_id = -1;
  auto* tensor = var->GetMutable<framework::LoDTensor>();
  if (platform::is_cpu_place(tensor->place())) {
    // Host tensor: read element 0 directly.
    auto data = reinterpret_cast<const float*>(tensor->data());
    micro_id = static_cast<int>(data[0]);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device tensor: stage the whole buffer to host, then read element 0.
    std::vector<char> temp;
    temp.resize(tensor->numel() * framework::DataTypeSize(tensor->dtype()));
    char* temp_ptr = temp.data();
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(platform::CPUPlace(), temp_ptr, tensor->place(),
                 tensor->data(),
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
    micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
  }
  return micro_id;
}
// Synchronous shutdown: issue the stop command to every heter worker and
// block until the broadcast completes.
void HeterClient::Stop() { StopHeterWorker().wait(); }
// Broadcasts PS_STOP_SERVER to all workers; table id -1 means the command
// targets no particular table.
std::future<int32_t> HeterClient::StopHeterWorker() {
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_STOP_SERVER, no_params);
}
// Broadcasts PS_START_PROFILER to all workers (no table, no arguments).
std::future<int32_t> HeterClient::StartProfiler() {
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_START_PROFILER, no_params);
}
// Broadcasts PS_STOP_PROFILER to all workers (no table, no arguments).
std::future<int32_t> HeterClient::StopProfiler() {
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_STOP_PROFILER, no_params);
}
void
HeterClient
::
CreateClient2XpuConnection
()
{
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
options
.
connection_type
=
"single"
;
options
.
timeout_ms
=
FLAGS_pserver_timeout_ms
;
xpu_channels_
.
resize
(
xpu_list_
.
size
());
for
(
size_t
i
=
0
;
i
<
xpu_list_
.
size
();
++
i
)
{
xpu_channels_
[
i
].
reset
(
new
brpc
::
Channel
());
if
(
xpu_channels_
[
i
]
->
Init
(
xpu_list_
[
i
].
c_str
(),
""
,
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"HeterClient channel init fail. Try Again"
;
auto
ip_port
=
paddle
::
string
::
Split
(
xpu_list_
[
i
],
':'
);
std
::
string
ip
=
ip_port
[
0
];
int
port
=
std
::
stoi
(
ip_port
[
1
]);
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
ip
,
port
);
if
(
xpu_channels_
[
i
]
->
Init
(
int_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"BrpcPsServer start failed, ip_port= "
<<
int_ip_port
;
}
}
}
previous_xpu_channels_
.
resize
(
previous_xpu_list_
.
size
());
for
(
size_t
i
=
0
;
i
<
previous_xpu_list_
.
size
();
++
i
)
{
previous_xpu_channels_
[
i
].
reset
(
new
brpc
::
Channel
());
if
(
previous_xpu_channels_
[
i
]
->
Init
(
previous_xpu_list_
[
i
].
c_str
(),
""
,
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"HeterClient channel init fail. Try Again"
;
auto
ip_port
=
paddle
::
string
::
Split
(
previous_xpu_list_
[
i
],
':'
);
std
::
string
ip
=
ip_port
[
0
];
int
port
=
std
::
stoi
(
ip_port
[
1
]);
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
ip
,
port
);
if
(
previous_xpu_channels_
[
i
]
->
Init
(
int_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"BrpcPsServer start failed, ip_port= "
<<
int_ip_port
;
}
}
}
}
// Serializes the send vars from `scope`, picks a channel based on the micro
// batch id and `mode` ("forward" -> next stage, "backward" -> previous
// stage), and issues an asynchronous SendAndRecvVariable RPC.  The
// "send_to_switch" path is currently disabled (the switch call is commented
// out) and returns without issuing any RPC.
void HeterClient::SendAndRecvAsync(
    const platform::DeviceContext& ctx, const framework::Scope& scope,
    const std::string& message_name,
    const std::vector<std::string>& send_var_name,
    const std::vector<std::string>& recv_var_name, const std::string& mode) {
  platform::RecordEvent record_event(
      "HeterClient->SendAndRecvAsync",
      platform::TracerEventType::Communication, 1);
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::vector<std::string> send_var_name_val = send_var_name;
  const std::vector<std::string> recv_var_name_val = recv_var_name;
  VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name;
  brpc::Channel* channel = nullptr;
  distributed::MultiVarMsg request;
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PADDLE_ENFORCE_NE(
        closure->cntl.Failed(), true,
        platform::errors::Unimplemented(
            "HeterClient::SendAndRecv meets brpc error, error message is %s",
            closure->cntl.ErrorText()));
    VLOG(4) << "call heter_worker success";
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::SerializeToMultiVarMsgAndIOBuf(
      message_name, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
      &request, &request_io_buffer);

  int micro_id = GetMicroId(ctx, p_scope);
  // global
  auto minibatch_id = micro_id / 10;
  VLOG(4) << "micro_id: " << micro_id;
  // select channel according to micro id
  if (mode == "forward") {
    int num = minibatch_id % xpu_channels_.size();
    channel = xpu_channels_[num].get();
  } else if (mode == "backward") {
    int num = minibatch_id % previous_xpu_channels_.size();
    channel = previous_xpu_channels_[num].get();
  } else if (mode == "send_to_switch") {
    VLOG(4) << "calling switch service";
    // auto promise = std::make_shared<std::promise<int32_t>>();
    // closure->add_promise(promise);
    // std::future<int> fut = promise->get_future();
    // int idx = 1;  // for test
    // LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size();
    // channel = xpu_channels_[idx].get();  // 为了适配 send_and_recv op
    // ::paddle::distributed::PsService_Stub stub(channel);
    // stub.SendToSwitch(&closure->cntl, &request, &closure->response,
    //                   closure); fut.wait();
    VLOG(4) << "calling switch service done";
    // Fix: this early return previously leaked `closure` — it was heap
    // allocated above but never handed to any stub on this path.
    delete closure;
    return;
  }
  // NOTE(review): an unrecognized `mode` leaves `channel` == nullptr here,
  // as in the original code — callers are expected to pass a valid mode.
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
                           closure);
}
// Broadcasts one PS command to every xpu channel and returns a future that
// resolves to 0 when every response checked out, -1 on the first failure.
// The closure owns the per-channel controllers/requests/responses and fires
// its callback once all RPCs have completed.
std::future<int32_t> HeterClient::SendCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
  size_t request_call_num = xpu_channels_.size();
  paddle::distributed::DownpourBrpcClosure* closure =
      new paddle::distributed::DownpourBrpcClosure(
          request_call_num, [request_call_num, cmd_id](void* done) {
            int ret = 0;
            auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
            // Any single failed response fails the whole broadcast.
            for (size_t i = 0; i < request_call_num; ++i) {
              if (closure->check_response(i, cmd_id) != 0) {
                ret = -1;
                break;
              }
            }
            closure->set_promise_value(ret);
          });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(trainer_id_);
    for (const auto& param : params) {
      closure->request(i)->add_params(param);
    }
    ::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
    closure->cntl(i)->set_timeout_ms(FLAGS_pserver_timeout_ms);
    // cmd msg don't limit timeout for save/load
    rpc_stub.service(closure->cntl(i), closure->request(i),
                     closure->response(i), closure);
  }
  return fut;
}
// Serializes the named variables from `scope` into a MultiVarMsg plus the
// request attachment and sends them to the switch service, blocking until
// the RPC completes.  Always returns 0 (failures surface via
// PADDLE_ENFORCE in the callback).
// NOTE(review): unlike the (group_id, ...) overload below, this path does
// not delete `closure` after fut.wait() — whether OnHeterRpcDone
// self-deletes in Run() must be confirmed in heter_server.h before
// "fixing" either side.
int HeterClient::Send(const platform::DeviceContext& ctx,
                      const framework::Scope& scope,
                      const std::string& message_name,
                      const std::vector<std::string>& send_var_names) {
  const framework::Scope* p_scope = &scope;  // note: intentionally const
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      PADDLE_ENFORCE_NE(
          closure->cntl.Failed(), true,
          platform::errors::Unimplemented(
              "HeterClient::SendToSwitch meets brpc error, error message is %s",
              closure->cntl.ErrorText()));
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);
  // 2. set req send_var_names(<string>)
  for (auto& send_var_name : send_var_names) {
    request.add_send_var_names(send_var_name);
  }
  // 3. set req var_messages(<VarMessage>)
  for (auto& send_var_name : send_var_names) {
    auto* send_var_msg = request.add_var_messages();
    send_var_msg->set_varname(send_var_name);
    framework::Variable* var = p_scope->FindVar(send_var_name);
    butil::IOBuf temp_iobuf;
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    request_io_buffer.append(temp_iobuf);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fall back to the first xpu channel when no switch channel is configured.
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  // brpc::Channel* channel = xpu_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  VLOG(4) << "waiting SendToSwitch response result......";
  fut.wait();
  VLOG(4) << "Send done";
  return 0;
}
// Raw-buffer variant of Send: ships `data_size` bytes from `data_ptr`
// (tagged with the group id, variable names, and per-variable lengths) to
// the switch service under the fixed message name "send and save".  Blocks
// until the RPC completes; always returns 0 — RPC failure is only logged.
int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
                      const std::vector<int64_t>& vars_size, void* data_ptr,
                      int64_t data_size) {
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      LOG(ERROR) << "Send meets brpc error, err msg is %s"
                 << closure->cntl.ErrorText();
    }
  });
  distributed::MultiVarMsg request;
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  std::string message_name = "send and save";
  request.set_message_name(message_name);
  request.set_group_id(group_id);
  for (auto& send_var_name : var_names) {
    request.add_send_var_names(send_var_name);
  }
  // Per-variable byte lengths let the receiver slice the flat attachment.
  for (auto var_len : vars_size) {
    request.add_vars_len(var_len);
  }
  auto& request_buffer = closure->cntl.request_attachment();
  request_buffer.append(reinterpret_cast<void*>(data_ptr), data_size);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fall back to the first xpu channel when no switch channel is configured.
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  fut.wait();
  delete closure;
  return 0;
}
// Pulls the named variables from the switch service and deserializes them
// into `recv_scope` on the CPU.  Blocks until the RPC completes; always
// returns 0 — RPC failure is only logged at VLOG(4).
int HeterClient::Recv(const platform::DeviceContext& ctx,
                      framework::Scope& recv_scope,  // NOLINT
                      const std::string& message_name,
                      const std::vector<std::string>& recv_var_names) {
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    VLOG(4) << "Recv service call done";
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      VLOG(4) << "HeterClient::RecvFromSwitch meets "
                 "brpc error, error message is %s"
              << closure->cntl.ErrorText();
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);
  // 2. set req recv_var_names(<string>)
  for (auto& recv_var_name : recv_var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fallback: use xpu_channels_[1] when no recv switch channel is configured.
  // NOTE(review): the other Recv overload falls back to xpu_channels_[0]
  // here — confirm which index is intended before unifying.
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    recv_switch_channels_.push_back(xpu_channels_[1]);
  }
  brpc::Channel* channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto& res_io_buffer = closure->cntl.response_attachment();
  VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf";
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope);
  VLOG(4) << "Recv done";
  return 0;
}
int
HeterClient
::
Recv
(
int
group_id
,
const
std
::
vector
<
std
::
string
>&
var_names
,
void
*
data_ptr
,
int64_t
data_size
)
{
OnHeterRpcDone
*
closure
=
new
OnHeterRpcDone
([](
void
*
done
)
{
auto
*
closure
=
reinterpret_cast
<
OnHeterRpcDone
*>
(
done
);
int
ret
=
0
;
closure
->
set_promise_value
(
ret
);
if
(
closure
->
cntl
.
Failed
())
{
LOG
(
ERROR
)
<<
"Recv meets brpc error, err msg is %s"
<<
closure
->
cntl
.
ErrorText
();
}
});
closure
->
cntl
.
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
distributed
::
MultiVarMsg
request
;
std
::
string
message_name
=
"query and recv"
;
request
.
set_message_name
(
message_name
);
request
.
set_group_id
(
group_id
);
for
(
auto
&
recv_var_name
:
var_names
)
{
request
.
add_recv_var_names
(
recv_var_name
);
}
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
if
(
recv_switch_channels_
.
empty
())
{
LOG
(
ERROR
)
<<
"peer_switch_channels_ is null, get xpu_channels_[1]"
;
if
(
xpu_channels_
.
size
()
<
2
)
{
LOG
(
ERROR
)
<<
"xpu_channels_ is null"
;
}
recv_switch_channels_
.
push_back
(
xpu_channels_
[
0
]);
}
brpc
::
Channel
*
channel
=
recv_switch_channels_
[
0
].
get
();
::
paddle
::
distributed
::
PsService_Stub
stub
(
channel
);
stub
.
RecvFromSwitch
(
&
closure
->
cntl
,
&
request
,
&
closure
->
response
,
closure
);
fut
.
wait
();
VLOG
(
4
)
<<
"RecvFromSwitch done"
;
// save in worker
auto
&
res_io_buffer
=
closure
->
cntl
.
response_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
res_io_buffer
);
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
data_ptr
),
data_size
);
delete
closure
;
VLOG
(
4
)
<<
"Recv done"
;
return
0
;
}
}
// namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/service/heter_client.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/string/split.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
DECLARE_int32
(
pserver_timeout_ms
);
namespace
paddle
{
namespace
distributed
{
// Shorthand aliases for the protobuf message types exchanged by HeterClient.
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
// Completion callback invoked when a heter RPC finishes; the void* argument
// is the OnHeterRpcDone closure itself.
// Modernized from `typedef` to `using` for consistency with the aliases above.
using HeterRpcCallbackFunc = std::function<void(void*)>;
// Closure passed to async brpc calls. Holds the request/response messages and
// the brpc::Controller for one RPC, and fans the completion value out to any
// number of attached promises. NOTE: Run() does NOT delete the closure; the
// caller that allocated it must free it after the RPC result is consumed.
class OnHeterRpcDone : public google::protobuf::Closure {
 public:
  explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
  virtual ~OnHeterRpcDone() {}
  // Invoked by brpc when the RPC completes; delegates to the user handler.
  void Run() { handler_(this); }
  // Attach a promise that will be fulfilled via set_promise_value().
  void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) {  // NOLINT
    _promises.push_back(promise);
  }
  // Fulfill every attached promise with `value` (0 == success by convention).
  void set_promise_value(int value) {
    for (auto& promise : _promises) {
      promise->set_value(value);
    }
  }
  // Placeholder; always reports success.
  int CheckResponse() { return 0; }
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
  HeterRpcCallbackFunc handler_;
  MultiVariableMessage request;
  MultiVariableMessage response;
  PsResponseMessage ps_response;
  brpc::Controller cntl;
  // PsRequestMessage *request(size_t i) { return &_requests[i]; }
  // PsResponseMessage *response(size_t i) { return &_responses[i]; }
  // std::vector<PsRequestMessage> _requests;
  // std::vector<PsResponseMessage> _responses;
  // std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Client side of the heterogeneous parameter-server / switch topology.
// Owns brpc channels to xpu nodes, peer switches, and peer workers, and
// exposes blocking Send/Recv primitives plus two lazily-built singletons.
class HeterClient {
 public:
  virtual ~HeterClient() {}
  // Create channels to every endpoint in `node_list`. `peer_role` selects
  // which channel vector is populated (switch vs worker); `need_encrypt`
  // toggles SSL. Failed inits are retried once with a numeric ip:port form.
  void InitClientChannels(bool need_encrypt,
                          const std::vector<std::string>& node_list,
                          int32_t peer_role) {
    brpc::ChannelOptions options;
    options.protocol = "baidu_std";
    options.connection_type = "single";
    options.timeout_ms = FLAGS_pserver_timeout_ms;
    std::vector<std::shared_ptr<brpc::Channel>>* client_channels = nullptr;
    if (peer_role == PEER_ROLE_IS_SWITCH) {
#ifdef PADDLE_WITH_ARM_BRPC
      if (need_encrypt) {
        options.mutable_ssl_options();
      }
      options.connection_type = "";
      VLOG(4) << "ssl enabled in arm";
#else
      options.ssl_options.enable = need_encrypt;
#endif
      client_channels = &peer_switch_channels_;
    } else if (peer_role == PEER_ROLE_IS_WORKER) {
      client_channels = &peer_worker_channels_;
    } else {
      LOG(ERROR) << "init switch client failed, peer_role not valid";
      // BUGFIX: bail out here — the original fell through and dereferenced
      // the null `client_channels` pointer below, crashing the process.
      return;
    }
    (*client_channels).resize(node_list.size());
    for (size_t i = 0; i < node_list.size(); ++i) {
      (*client_channels)[i].reset(new brpc::Channel());
      if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
          0) {
        VLOG(0) << "client channel init failed! try again";
        // Retry with the numeric "ip:port" form of the endpoint.
        auto ip_port = paddle::string::Split(node_list[i], ':');
        std::string ip = ip_port[0];
        int port = std::stoi(ip_port[1]);
        std::string int_ip_port = GetIntTypeEndpoint(ip, port);
        if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "client channel init failed! peer ip_port = "
                     << int_ip_port;
        }
      }
    }
    VLOG(4) << "InitClientChannels success";
  }
  void CreateClient2XpuConnection();
  void SendAndRecvAsync(const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
                        const std::string& message_name,
                        const std::vector<std::string>& send_var_name,
                        const std::vector<std::string>& recv_var_name,
                        const std::string& mode = "forward");
  int Send(int group_id, const std::vector<std::string>& var_names,
           const std::vector<int64_t>& vars_len, void* data_ptr,
           int64_t data_size);
  int Send(const platform::DeviceContext& ctx, const framework::Scope& scope,
           const std::string& message_name,
           const std::vector<std::string>& send_var_names);
  int Recv(int group_id, const std::vector<std::string>& var_names,
           void* data_ptr, int64_t data_size);
  int Recv(const platform::DeviceContext& ctx,
           framework::Scope& recv_scope,  // NOLINT
           const std::string& message_name,
           const std::vector<std::string>& recv_var_names);
  // HeterClient singleton. NOTE(review): unlike GetSwitchInstance this is
  // not guarded by mtx_ — confirm all callers run on one thread.
  static std::shared_ptr<HeterClient> GetInstance(
      const std::vector<std::string>& endpoints,
      const std::vector<std::string>& previous_endpoints,
      const int& trainer_id) {
    if (NULL == s_instance_) {
      s_instance_.reset(new HeterClient());
      s_instance_->SetXpuList(endpoints);
      s_instance_->SetPreviousXpuList(previous_endpoints);
      s_instance_->SetTrainerID(trainer_id);
      s_instance_->CreateClient2XpuConnection();
    }
    return s_instance_;
  }
  // switch client singleton
  static std::shared_ptr<HeterClient> GetSwitchInstance(
      const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (peer_endpoints.empty()) {
      VLOG(4) << "init switch client failed, null peer_endpoints";
    }
    VLOG(4) << "peer role is: " << peer_role
            << ", addr is: " << peer_endpoints[0];
    if (switch_s_instance_ == nullptr) {
      switch_s_instance_.reset(new HeterClient());
      switch_s_instance_->SetPeerSwitchList(peer_endpoints);
      switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role);
    }
    return switch_s_instance_;
  }
  void SetPeerSwitchList(const std::vector<std::string>& peer_endpoints) {
    peer_switch_list_ = peer_endpoints;
  }
  void SetPeerWorkerList(const std::vector<std::string>& worker_endpoints) {
    peer_worker_list_ = worker_endpoints;
  }
  void Stop();
  std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
                               const std::vector<std::string>& params);
  std::future<int32_t> StartProfiler();
  std::future<int32_t> StopProfiler();
  std::future<int32_t> StopHeterWorker();
  std::vector<std::string>& GetXpuList() { return xpu_list_; }
  void SetXpuList(const std::vector<std::string>& xpu_list) {
    xpu_list_ = xpu_list;
  }
  void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
    previous_xpu_list_ = xpu_list;
  }
  void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }

 public:
  std::vector<std::string> send_switch_list_;
  std::vector<std::string> recv_switch_list_;
  std::vector<std::string> peer_switch_list_;
  std::vector<std::string> peer_worker_list_;
  std::vector<std::shared_ptr<brpc::Channel>> send_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> recv_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_worker_channels_;

 private:
  HeterClient() {}
  HeterClient& operator=(const HeterClient&);
  HeterClient(const HeterClient&);
  static std::shared_ptr<HeterClient> s_instance_;
  static std::mutex mtx_;
  static std::shared_ptr<HeterClient> switch_s_instance_;
  std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
  // DISABLE_COPY_AND_ASSIGN(HeterClient);
  std::vector<std::string> xpu_list_;
  std::vector<std::string> previous_xpu_list_;
  int trainer_id_;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/service/heter_server.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/string/split.h"
namespace
paddle
{
namespace
distributed
{
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
// Out-of-line definitions of HeterServer's static singleton state.
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
std::mutex HeterServer::mtx_;
// Register `func` to handle incoming requests whose message name equals
// `message_name`; simply delegates to the underlying service object.
void HeterServer::RegisterServiceHandler(std::string message_name,
                                         HeterServiceHandler func) {
  service_.RegisterServiceHandler(message_name, func);
}
void
HeterServer
::
StartHeterService
(
bool
neeed_encrypt
)
{
server_
.
AddService
(
&
service_
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
);
brpc
::
ServerOptions
options
;
if
(
neeed_encrypt
)
{
#ifdef PADDLE_WITH_ARM_BRPC
options
.
mutable_ssl_options
()
->
default_cert
.
certificate
=
"/cert.pem"
;
options
.
mutable_ssl_options
()
->
default_cert
.
private_key
=
"/key.pem"
;
#else
options
.
ssl_options
.
default_cert
.
certificate
=
"/cert.pem"
;
options
.
ssl_options
.
default_cert
.
private_key
=
"/key.pem"
;
#endif
}
if
(
server_
.
Start
(
endpoint_
.
c_str
(),
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"HeterServer start fail. Try again."
;
auto
ip_port
=
paddle
::
string
::
Split
(
endpoint_
,
':'
);
std
::
string
ip
=
ip_port
[
0
];
int
port
=
std
::
stoi
(
ip_port
[
1
]);
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
ip
,
port
);
if
(
server_
.
Start
(
endpoint_
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"HeterServer start failed, ip_port= "
<<
int_ip_port
;
}
}
else
{
VLOG
(
0
)
<<
"heter server start success! listen on "
<<
endpoint_
;
}
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
stoped_
=
false
;
ready_
=
1
;
}
condition_ready_
.
notify_all
();
VLOG
(
4
)
<<
"stopped: "
<<
stoped_
<<
", ready_: "
<<
ready_
;
std
::
unique_lock
<
std
::
mutex
>
running_lock
(
mutex_
);
cv_
.
wait
(
running_lock
,
[
&
]
{
VLOG
(
4
)
<<
"Heter Server is Stop? "
<<
stoped_
;
return
stoped_
;
});
VLOG
(
4
)
<<
"start service done"
;
}
void
HeterServer
::
StartHeterInterService
(
bool
neeed_encrypt
)
{
server_inter_
.
AddService
(
&
service_
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
);
brpc
::
ServerOptions
options
;
if
(
neeed_encrypt
)
{
#ifdef PADDLE_WITH_ARM_BRPC
options
.
mutable_ssl_options
()
->
default_cert
.
certificate
=
"/cert.pem"
;
options
.
mutable_ssl_options
()
->
default_cert
.
private_key
=
"/key.pem"
;
#else
options
.
ssl_options
.
default_cert
.
certificate
=
"/cert.pem"
;
options
.
ssl_options
.
default_cert
.
private_key
=
"/key.pem"
;
#endif
}
if
(
server_inter_
.
Start
(
endpoint_inter_
.
c_str
(),
&
options
)
!=
0
)
{
VLOG
(
4
)
<<
"switch inter server start fail. Try again."
;
auto
ip_port
=
paddle
::
string
::
Split
(
endpoint_inter_
,
':'
);
std
::
string
ip
=
ip_port
[
0
];
int
port
=
std
::
stoi
(
ip_port
[
1
]);
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
ip
,
port
);
if
(
server_inter_
.
Start
(
endpoint_inter_
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"switch inter server start failed, ip_port= "
<<
int_ip_port
;
}
}
else
{
VLOG
(
4
)
<<
"switch inter server server start success! listen on "
<<
endpoint_inter_
;
}
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
stoped_
=
false
;
ready_
=
1
;
}
condition_ready_
.
notify_all
();
VLOG
(
4
)
<<
"stopped: "
<<
stoped_
<<
", ready_: "
<<
ready_
;
std
::
unique_lock
<
std
::
mutex
>
running_lock
(
mutex_
);
cv_
.
wait
(
running_lock
,
[
&
]
{
VLOG
(
4
)
<<
"Heter Server is Stop? "
<<
stoped_
;
return
stoped_
;
});
VLOG
(
4
)
<<
"start service done"
;
}
// Forward the fan-in (number of expected upstream clients) to the service.
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
// Block the caller until StartHeterService()/StartHeterInterService() has
// set ready_ to 1 and notified condition_ready_.
void HeterServer::WaitServerReady() {
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
  while (this->ready_ != 1) {
    condition_ready_.wait(lock);
  }
}
int
SendAndRecvVariableHandler
::
SaveInSwitchWithShard
(
const
MultiVarMsg
*
request
,
PsResponseMessage
*
response
,
brpc
::
Controller
*
cntl
)
{
VLOG
(
4
)
<<
"entering SaveInSwitchWithShard"
;
int32_t
group_id
=
request
->
group_id
();
if
(
group_id
>=
FLAGS_heter_world_size
)
{
LOG
(
ERROR
)
<<
"group id exceed maxmium"
;
}
auto
&
local_shard
=
_local_shards
[
group_id
];
auto
&
request_io_buffer
=
cntl
->
request_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
request_io_buffer
);
for
(
int
idx
=
0
;
idx
<
request
->
send_var_names_size
();
idx
++
)
{
const
auto
&
var_name
=
request
->
send_var_names
(
idx
);
const
auto
&
var_size
=
request
->
vars_len
(
idx
);
WaitForVarsConsumed
(
group_id
,
var_name
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
scope_mutex_
);
auto
&
value
=
local_shard
[
var_name
];
value
.
resize
(
var_size
);
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
value
.
data
()),
var_size
);
vars_ready_flag
[
group_id
][
var_name
]
=
1
;
VLOG
(
4
)
<<
"saved var_name: "
<<
var_name
<<
"is saved ready!"
;
}
VLOG
(
4
)
<<
"SaveInSwitchWithShard success"
;
return
0
;
}
// Serve a "query and recv" request: for each requested variable, wait until a
// producer has stored it in the per-group shard, append its bytes to the
// response attachment, then mark it consumed so the producer may overwrite it.
int SendAndRecvVariableHandler::QueryInSwitchWithShard(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  VLOG(4) << "entering QueryInSwitchWithShard";
  int32_t group_id = request->group_id();
  VLOG(4) << "group id: " << group_id;
  auto& local_shard = _local_shards[group_id];
  auto& response_io_buffer = cntl->response_attachment();
  // Copy the requested variable names out of the protobuf request.
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto msg_name = request->message_name();
  response->set_message_name(msg_name);
  for (auto& req_var_name : req_var_names) {
    VLOG(4) << "req var name: " << req_var_name;
    response->add_send_var_names(req_var_name);
    // Block until SaveInSwitchWithShard has flagged this var as produced.
    WaitForVarsProduced(group_id, req_var_name);
    std::unique_lock<std::mutex> lk(scope_mutex_);
    // NOTE(review): the find() result is not checked; assumes the var must
    // exist after WaitForVarsProduced — confirm SparseTableShard::find
    // cannot miss here.
    auto itr = local_shard.find(req_var_name);
    auto& value = itr.value();
    response_io_buffer.append(value.data(), value.size());
    value.resize(0);  // release the consumed buffer (was: 清空内存)
    // Flip the flag back so the producer side may publish the next value.
    vars_ready_flag[group_id][req_var_name] = 0;
    VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!";
  }
  VLOG(4) << "heter server QueryInSwitchWithShard done";
  return 0;
}
int
SendAndRecvVariableHandler
::
SaveInSwitchWithScope
(
const
MultiVarMsg
*
request
,
PsResponseMessage
*
response
,
brpc
::
Controller
*
cntl
)
{
VLOG
(
4
)
<<
"entering SaveInSwitchWithScope"
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
CPUPlace
cpu_place
;
auto
&
cpu_dev_ctx
=
*
pool
.
Get
(
cpu_place
);
auto
message_name
=
request
->
message_name
();
VLOG
(
4
)
<<
"message_name in heter server: "
<<
message_name
;
auto
send_var_nums
=
request
->
send_var_names_size
();
std
::
vector
<
std
::
string
>
send_var_names
(
send_var_nums
);
for
(
int
idx
=
0
;
idx
<
send_var_nums
;
idx
++
)
{
send_var_names
[
idx
]
=
request
->
var_messages
(
idx
).
varname
();
}
std
::
unique_lock
<
std
::
mutex
>
lk
(
scope_mutex_
);
auto
local_scope
=
local_scope_ptr
.
get
();
if
(
!
local_scope
)
{
LOG
(
ERROR
)
<<
"local_scope_ptr is null in SaveInSwitchWithScope"
;
}
for
(
auto
var_name
:
send_var_names
)
{
auto
*
var_exist_ptr
=
local_scope
->
FindVar
(
var_name
);
if
(
!
var_exist_ptr
)
{
VLOG
(
4
)
<<
"not find var: "
<<
var_name
<<
" in local_scope"
;
}
WaitForVarsConsumed
(
0
,
var_name
);
}
auto
&
request_io_buffer
=
cntl
->
request_attachment
();
distributed
::
DeserializeFromMultiVarMsgAndIOBuf
(
*
request
,
&
request_io_buffer
,
cpu_dev_ctx
,
local_scope
);
lk
.
unlock
();
for
(
auto
var_name
:
send_var_names
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
scope_mutex_
);
vars_ready_flag
[
0
][
var_name
]
=
1
;
}
VLOG
(
4
)
<<
"SaveInSwitchWithScope success"
;
return
0
;
}
// Serve a query against the switch-local scope: wait until each requested
// variable is produced, serialize it into the response attachment, then mark
// it consumed. Returns 0 on success, -1 if a requested variable is missing.
int SendAndRecvVariableHandler::QueryInSwitchWithScope(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  VLOG(4) << "entering QueryInSwitchWithScope";
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    LOG(INFO) << "local_scope is null";
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  // get req message_name & req_var_names
  auto msg_name = request->message_name();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto& response_io_buffer = cntl->response_attachment();
  // 1. fill message_name(string)
  response->set_message_name(msg_name);
  // 2. fill var_names(string)
  for (auto& req_var_name : req_var_names) {
    response->add_send_var_names(req_var_name);
  }
  // 3. fill var_messages(VarMessage)
  for (auto& req_var_name : req_var_names) {
    WaitForVarsProduced(0, req_var_name);
    auto* send_var_msg = response->add_var_messages();
    send_var_msg->set_varname(req_var_name);
    framework::Variable* var_ptr;
    var_ptr = local_scope->FindVar(req_var_name);
    if (!var_ptr) {
      LOG(INFO) << "local_scope not find var: " << req_var_name;
      // BUGFIX: the original fell through and dereferenced the null var_ptr
      // (var_ptr->IsType<...>()); fail the request instead of crashing.
      return -1;
    }
    butil::IOBuf temp_iobuf;
    if (var_ptr->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    } else if (var_ptr->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    }
    response_io_buffer.append(temp_iobuf);
  }
  // Mark every served variable as consumed so producers may refill them.
  for (auto& req_var_name : req_var_names) {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    vars_ready_flag[0][req_var_name] = 0;
  }
  VLOG(4) << "heter server QueryInSwitchWithScope done";
  return 0;
}
}
// end namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/heter_server.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
namespace
paddle
{
namespace
framework
{
class
Executor
;
class
ProgramDesc
;
class
Scope
;
}
// namespace framework
}
// namespace paddle
DECLARE_double
(
eager_delete_tensor_gb
);
DECLARE_int32
(
pserver_timeout_ms
);
DECLARE_int32
(
heter_world_size
);
DECLARE_int32
(
switch_send_recv_timeout_s
);
namespace
paddle
{
namespace
distributed
{
// Shorthand aliases for the protobuf messages used by the heter server.
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
// Handler for control-plane Ps commands (stop/profiler), keyed by cmd_id.
using serviceHandler =
    std::function<int32_t(const PsRequestMessage& request,
                          PsResponseMessage& response,  // NOLINT
                          brpc::Controller* cntl)>;
// Handler for data-plane variable send/recv requests, keyed by message name.
using HeterServiceHandler =
    std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;
// Completion callback for async heter RPCs (argument is the closure itself).
using HeterRpcCallbackFunc = std::function<void(void*)>;
// Abstract base for request handlers: carries the (non-owning) device
// context and root scope the concrete handler operates on.
class ServiceHandlerBase {
 public:
  ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}
  virtual ~ServiceHandlerBase() {}
  // Both setters store raw, non-owning pointers; the caller keeps ownership
  // and must keep them alive for the handler's lifetime.
  void SetScope(const framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  // Process one request/response pair; implemented by concrete handlers.
  virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
                     brpc::Controller* cntl) = 0;

 protected:
  const platform::DeviceContext* dev_ctx_;
  const framework::Scope* scope_;
};
// Scopes shared with HeterPipelineTrainer, keyed by minibatch index.
using SharedMiniScope =
    std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;
// Per-minibatch list of microbatch scopes.
using SharedMicroScope = std::shared_ptr<std::unordered_map<
    int, std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;
// Per-minibatch queue of (message_name, microbatch_index) work items.
using SharedTaskQueue = std::shared_ptr<
    std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
                                std::pair<std::string, int>>>>>;
// Byte buffer the switch keeps for one variable: a thin wrapper over
// std::vector<char> exposing only the operations the shard code needs.
class ValueInSwitch {
 public:
  ValueInSwitch() {}
  ~ValueInSwitch() {}
  // Mutable pointer to the first byte of the payload.
  char* data() { return _data.data(); }
  // Current payload length in bytes.
  size_t size() { return _data.size(); }
  // Grow or shrink the payload to `new_size` bytes.
  void resize(size_t new_size) { _data.resize(new_size); }
  // Return excess capacity to the allocator after a shrink.
  void shrink_to_fit() { _data.shrink_to_fit(); }

 private:
  std::vector<char> _data;  // backing storage
};
// Handler for variable send/recv on the heter server. Serves two roles:
//  * trainer side: Handle() routes incoming microbatch variables into
//    per-mini/microbatch scopes and enqueues work for heter workers;
//  * switch side: Save/Query methods stage raw variable bytes in
//    _local_shards (or a local scope), synchronized through the
//    vars_ready_flag produce/consume flags guarded by scope_mutex_.
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
 public:
  SendAndRecvVariableHandler() {
    this->num_microbatch_ = 0;
    this->num_minibatch_ = 0;
    // One shard per group id (bounded by FLAGS_heter_world_size).
    _local_shards.reset(new shard_type[FLAGS_heter_world_size]);
  }
  virtual ~SendAndRecvVariableHandler() {}
  void SetMiniScopes(SharedMiniScope mini_scopes) {
    mini_scopes_ = mini_scopes;
    num_minibatch_ = mini_scopes_->size();
  }
  // Record the shared microbatch scopes; the microbatch count is taken from
  // the first minibatch entry (all entries are assumed to be the same size).
  void SetMicroScopes(SharedMicroScope micro_scopes) {
    micro_scopes_ = micro_scopes;
    for (auto& scope_pair : (*micro_scopes_)) {
      // auto mini_idx = scope_pair.first;
      auto& micro_scopes = scope_pair.second;
      num_microbatch_ = micro_scopes->size();
      break;
    }
  }
  // Number of task queues (one per minibatch) currently registered.
  int GetThreadNum() {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    return (*task_queue_).size();
  }
  int SaveInSwitchWithScope(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);
  // Spin until the consumer has cleared var_name's ready flag for group_id.
  // Locks scope_mutex_ per probe — callers must NOT hold scope_mutex_.
  void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 0) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not consumed exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }
  // Spin until a producer has set var_name's ready flag for group_id.
  // Locks scope_mutex_ per probe — callers must NOT hold scope_mutex_.
  void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 1) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not produced exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }
  int SaveInSwitchWithShard(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);
  int QueryInSwitchWithShard(const MultiVarMsg* request, MultiVarMsg* response,
                             brpc::Controller* cntl);
  int QueryInSwitchWithScope(const MultiVarMsg* request, MultiVarMsg* response,
                             brpc::Controller* cntl);
  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
  // Trainer-side entry point: deserialize the request twice (once into a
  // temporary scope to read "microbatch_id", once into the target microbatch
  // scope), enqueue the work item, and echo the requested var names back.
  int Handle(const MultiVarMsg* request, MultiVarMsg* response,
             brpc::Controller* cntl) override {
    LOG(INFO) << "entered Handle";
    platform::RecordEvent record_event(
        "SendAndRecvVariableHandler->Handle",
        platform::TracerEventType::Communication, 1);
    FLAGS_eager_delete_tensor_gb = -1;
    // get microID from request
    // deserialize variable to micro scope
    // Push to heter worker's task_queue
    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
        new paddle::framework::Scope());
    auto& local_scope = *(local_scope_ptr.get());
    platform::DeviceContextPool& pool =
        platform::DeviceContextPool::Instance();
    platform::CPUPlace cpu_place;
    auto& cpu_dev_ctx = *pool.Get(cpu_place);
    auto message_name = request->message_name();
    auto& request_io_buffer = cntl->request_attachment();
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);
    auto* var = local_scope.FindVar("microbatch_id");
    PADDLE_ENFORCE_NE(var, nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable microbatch_id in scope."));
    auto* tensor = var->GetMutable<framework::LoDTensor>();
    auto data = reinterpret_cast<const float*>(tensor->data());
    auto micro_id = static_cast<int>(data[0]);
    VLOG(4) << "micro_id in heter server: " << micro_id;
    // micro_id encodes both indices: minibatch = id / 10, microbatch = id % 10
    // (so at most 10 microbatches per minibatch).
    int minibatch_index = micro_id / 10;
    int microbatch_index = micro_id % 10;
    // check minibatch_index is in mini_scopes_
    std::unique_lock<std::mutex> lk(scope_mutex_);
    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
      lk.unlock();
      PADDLE_ENFORCE_EQ(
          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
          platform::errors::InvalidArgument(
              "minibatch index should in current trainer"));
    } else {
      // create mini scope & micro scopes
      auto* minibatch_scope = &(scope_->NewScope());
      (*mini_scopes_)[minibatch_index] = minibatch_scope;
      (*micro_scopes_)[minibatch_index].reset(
          new std::vector<paddle::framework::Scope*>{});
      for (int i = 0; i < num_microbatch_; i++) {
        auto* micro_scope = &(minibatch_scope->NewScope());
        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
      }
      (*task_queue_)[minibatch_index].reset(
          new ::paddle::framework::BlockingQueue<
              std::pair<std::string, int>>());
      lk.unlock();
    }
    auto* micro_scope =
        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, *dev_ctx_, micro_scope);
    // blocking queue handles multi thread
    VLOG(4) << "Handle in HeterServer: " << message_name << ", "
            << microbatch_index;
    VLOG(4) << "task_queue_ size: " << task_queue_->size();
    (*task_queue_)[minibatch_index]->Push(
        std::make_pair(message_name, microbatch_index));
    auto response_var_nums = request->recv_var_names_size();
    std::vector<std::string> response_var_names(response_var_nums),
        empty_var_names{};
    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
      response_var_names[var_idx] = request->recv_var_names(var_idx);
    }
    auto& response_io_buffer = cntl->response_attachment();
    distributed::SerializeToMultiVarMsgAndIOBuf(
        message_name, response_var_names, empty_var_names, *dev_ctx_,
        &local_scope, response, &response_io_buffer);
    VLOG(4) << "Handle over";
    return 0;
  }

 public:
  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
  std::shared_ptr<paddle::framework::Scope> local_scope_ptr;
  // for switch
  std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
      vars_ready_flag;
  std::unique_ptr<shard_type[]> _local_shards;
  platform::Timer timeline_;

 private:
  // share with HeterPipelineTrainer
  SharedMiniScope mini_scopes_{nullptr};
  SharedMicroScope micro_scopes_{nullptr};
  int num_microbatch_;
  int num_minibatch_;
  // Guards vars_ready_flag, _local_shards contents, and scope maps.
  std::mutex scope_mutex_;
  bool is_first_stage_ = false;
  bool is_last_stage_ = false;
  SharedTaskQueue task_queue_;
};
// brpc service that routes heterogeneous-parameter-server traffic between
// trainers, switches and workers. Command-style requests (PsRequestMessage)
// are dispatched through _service_handler_map; variable transfer requests
// (MultiVarMsg) are dispatched by message name through handler_map_.
class HeterService : public PsService {
 public:
  HeterService() {
    // Wire the built-in command handlers (stop server / start / stop profiler).
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker, this,
                  std::placeholders::_1, std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler, this, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler, this, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3);
    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  virtual ~HeterService() {}

  // Dispatch a generic PS command by cmd_id. On an unknown cmd_id the call
  // returns with err_code still 0.
  // NOTE(review): err_msg below is built but never attached to the response
  // (no set_err_msg call) — looks like a missing response update; confirm.
  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }

  // Dispatch a variable send/recv request to the handler registered under
  // request->message_name(); enforces that the handler exists.
  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller, const MultiVarMsg* request,
      MultiVarMsg* response, ::google::protobuf::Closure* done) {
    // This object helps you to call done->Run() in RAII style. If you need
    // to process the request asynchronously, pass done_guard.release().
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr, handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
            "which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
    // We don't want to call done->Run() here, release the guard.
    // done_guard.release();
  }

  // Pop data previously stored in the switch's local shard for this request.
  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request, MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.QueryInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "QueryInSwitchWithScope failed!";
    }
    // response->set_message_name(message_name);
  }

  // Forward a request to the peer switch over channel 0, synchronously:
  // blocks on a promise fulfilled by the RPC's completion closure.
  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    std::shared_ptr<HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_->peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
    }
    brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // proxy: create a new OnHeterRpcDone object (or reset inside the
    // OnHeterRpcDone class) for the forwarded call.
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(), true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(cntl->request_attachment().movable());
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    // brpc::Controller std_cntl;
    // std_cntl.request_attachment().append(cntl->request_attachment().movable());
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    // NOTE(review): the response attachment is appended BEFORE fut.wait();
    // since SendS2S runs asynchronously (closure2), the attachment may still
    // be empty at this point — confirm intended ordering.
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    fut.wait();
    VLOG(4) << "SendToSwitch done";
    delete closure2;
  }

  // Receive data forwarded by a peer switch and persist it into the local
  // shard store; always fills err_msg/err_code on the response.
  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request, PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.SaveInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // if (itr == handler_map_.end()) {
    //   LOG(ERROR) << "can not find func handler";
    //}
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "SaveInSwitchWithScope failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }

  // Forward a request from the switch to peer worker 0.
  // NOTE(review): `done` is both held by done_guard and passed to the async
  // stub call; that implies done->Run() could fire twice — confirm the
  // OnHeterRpcDone contract makes this safe.
  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request, PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
    std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_->peer_worker_list_[0];
    brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PsService_Stub stub(channel);
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
    // fill response content
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
  }

  // Register a named handler for SendAndRecvVariable dispatch.
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }

  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }
  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }
  // fan_in_: number of distinct clients that must request a stop before the
  // service marks itself exited (see stop_heter_worker).
  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }

  // Mark the service exited immediately, bypassing the fan-in count.
  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }
  bool IsExit() { return is_exit_; }

 private:
  int32_t stop_profiler(const PsRequestMessage& request,
                        PsResponseMessage& response,  // NOLINT
                        brpc::Controller* cntl) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }

  int32_t start_profiler(const PsRequestMessage& request,
                         PsResponseMessage& response,  // NOLINT
                         brpc::Controller* cntl) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }

  // Record this client's stop request; once fan_in_ distinct clients have
  // asked, flip is_exit_ so the server loop can shut down.
  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response,  // NOLINT
                            brpc::Controller* cntl) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;
  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};
// Process-wide singleton wrapping the brpc servers (main + inter) that host
// HeterService. Owns the service object and a shared request handler; most
// setters simply forward configuration into those two.
class HeterServer {
 public:
  HeterServer() : ready_(0) {}
  virtual ~HeterServer() {}

  // Idempotent shutdown: force-exit the service if it has not exited on its
  // own, wake waiters, then stop and join the brpc server.
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_ == true) return;
    if (!IsExit()) {
      service_.ForceExit();
    }
    stoped_ = true;
    cv_.notify_all();
    server_.Stop(1000);
    server_.Join();
  }

  bool IsStop() {
    std::unique_lock<std::mutex> lock(mutex_);
    return stoped_;
  }

  // Exit state is owned by the service (fan-in stop or ForceExit).
  bool IsExit() { return service_.IsExit(); }

  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func);

  void StartHeterService(bool need_encrypt = false);

  void StartHeterInterService(bool need_encrypt = false);

  void SetEndPoint(const std::string& endpoint) {
    this->endpoint_ = endpoint;
    service_.SetEndpoint(endpoint);
  }

  void SetLocalScope() {
    request_handler_->local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  void SetInterEndpoint(const std::string& endpoint) {
    this->endpoint_inter_ = endpoint;
    service_.SetInterEndpoint(endpoint);
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    this->peer_endpoints_ = peer_endpoints;
    service_.SetPeerEndPoints(peer_endpoints);
  }

  void SetFanin(const int& fan_in);

  void SetServiceHandler(
      std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
    request_handler_ = request_handler;
  }

  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
    request_handler_->SetMiniScopes(mini_scopes);
  }

  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
    request_handler_->SetMicroScopes(micro_scopes);
  }

  int GetThreadNum() { return request_handler_->GetThreadNum(); }

  void SetTaskQueue(SharedTaskQueue task_queue) {
    request_handler_->SetTaskQueue(task_queue);
  }

  // HeterWrapper singleton
  static std::shared_ptr<HeterServer> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServer());
    }
    return s_instance_;
  }

  void WaitServerReady();

 private:
  static std::shared_ptr<HeterServer> s_instance_;
  mutable std::mutex mutex_;           // guards stoped_
  static std::mutex mtx_;              // guards singleton creation
  std::condition_variable cv_;
  std::condition_variable condition_ready_;
  bool stoped_ = true;                 // starts true; cleared when serving
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

 protected:
  brpc::Server server_;
  brpc::Server server_inter_;
  HeterService service_;
  std::shared_ptr<SendAndRecvVariableHandler> request_handler_;
  DISABLE_COPY_AND_ASSIGN(HeterServer);
  std::mutex mutex_ready_;
  int ready_;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/service/ps_client.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace
paddle
{
namespace
distributed
{
// Register the concrete PSClient implementations with the PSCore factory so
// PSClientFactory::Create can instantiate them by class-name string.
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
// Store configuration/environment, mirror the server-side table parameters
// into the worker config, build one ValueAccessor per table, then delegate to
// the subclass Initialize().
int32_t PSClient::Configure(
    const PSParameter& config,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
        regions,
    PSEnvironment& env, size_t client_id) {
  _env = &env;
  _config = config;
  _dense_pull_regions = regions;
  _client_id = client_id;
  // The worker reuses the server's table parameters verbatim.
  _config.mutable_worker_param()
      ->mutable_downpour_worker_param()
      ->mutable_downpour_table_param()
      ->CopyFrom(_config.server_param()
                     .downpour_server_param()
                     .downpour_table_param());
  const auto& work_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < work_param.downpour_table_param_size(); ++i) {
    // Instantiate the accessor class named in the config, configure and
    // initialize it, and index it by table id (ownership taken by reset()).
    auto* accessor = CREATE_PSCORE_CLASS(
        ValueAccessor,
        work_param.downpour_table_param(i).accessor().accessor_class());
    accessor->Configure(work_param.downpour_table_param(i).accessor());
    accessor->Initialize();
    _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
        accessor);
  }
  return Initialize();
}
// Factory entry point: validate the server-side config, instantiate the
// PSClient subclass named in service_param.client_class(), and initialize the
// global TableManager.
//
// @param ps_config  full PS configuration proto.
// @return owning raw pointer to the new client, or nullptr when the config is
//         incomplete or the class name is not registered. Caller owns the
//         result (matches the existing factory contract).
PSClient* PSClientFactory::Create(const PSParameter& ps_config) {
  const auto& config = ps_config.server_param();
  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }

  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }

  if (!config.downpour_server_param().service_param().has_client_class()) {
    LOG(ERROR) << "miss client_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }

  const auto& service_param = config.downpour_server_param().service_param();
  PSClient* client =
      CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
  if (client == nullptr) {
    LOG(ERROR) << "client is not registered, server_name:"
               << service_param.client_class();
    return nullptr;
  }

  TableManager::Instance().Initialize();
  VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
  return client;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_client.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
// Completion callback invoked with the finished closure (as void*) when an
// asynchronous PS request completes.
typedef std::function<void(void*)> PSClientCallBack;
// protobuf Closure that fans a single RPC completion out to any number of
// promises (so several callers can wait on one request) and keeps cost
// timers alive until completion.
class PSClientClosure : public google::protobuf::Closure {
 public:
  explicit PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
  virtual ~PSClientClosure() {}

  // Fulfill every registered promise with the request's result code.
  virtual void set_promise_value(int value) {
    for (auto& promise : _promises) {
      promise->set_value(value);
    }
  }

  void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) {  // NOLINT
    _promises.push_back(promise);
  }

  void add_timer(std::shared_ptr<CostTimer>& timer) {  // NOLINT
    _timers.push_back(timer);
  }

 protected:
  PSClientCallBack _callback;
  std::vector<std::shared_ptr<CostTimer>> _timers;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
// Abstract parameter-server client interface. Concrete transports
// (BrpcPsClient, PsLocalClient, GraphBrpcClient) implement the pure virtual
// methods; optional features default to a ready future holding -1.
class PSClient {
 public:
  PSClient() {}
  virtual ~PSClient() {}
  PSClient(PSClient&&) = delete;
  PSClient(const PSClient&) = delete;

  // Non-overridable setup: stores config/env and builds table accessors
  // before calling the subclass Initialize() (see ps_client.cc).
  virtual int32_t Configure(  // NOLINT
      const PSParameter& config,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
          regions,
      PSEnvironment& _env, size_t client_id) final;  // NOLINT

  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry) = 0;

  // Trigger eviction (shrink) of a table's data.
  virtual std::future<int32_t> Shrink(uint32_t table_id,
                                      const std::string threshold) = 0;

  // Load data for all tables.
  virtual std::future<int32_t> Load(const std::string& epoch,
                                    const std::string& mode) = 0;
  // Load data for a specific table.
  virtual std::future<int32_t> Load(uint32_t table_id,
                                    const std::string& epoch,
                                    const std::string& mode) = 0;

  // Save all tables; the value_accessor may apply different save conditions
  // depending on `mode`.
  virtual std::future<int32_t> Save(const std::string& epoch,
                                    const std::string& mode) = 0;
  // Save a specific table; the value_accessor may apply different save
  // conditions depending on `mode`.
  virtual std::future<int32_t> Save(uint32_t table_id,
                                    const std::string& epoch,
                                    const std::string& mode) = 0;

  // Clear table data (all tables / one table).
  virtual std::future<int32_t> Clear() = 0;
  virtual std::future<int32_t> Clear(uint32_t table_id) = 0;

  // Pull the dense part of the parameters and scatter it into the local
  // network-parameter regions.
  // - start/num can select a sub-range of the parameters.
  // - The keys/values buffers must not be reused before the future resolves.
  // - The client splits values into chunks handed to multiple senders;
  //   senders aggregate requests for the same chunk across multiple buffers.
  // - The server extracts and returns the configured dimension of each chunk;
  //   returned data is unpacked into the accumulated buffers.
  virtual std::future<int32_t> PullDense(Region* regions, size_t region_num,
                                         size_t table_id) = 0;

  // Reserved.
  // firstly push dense param for parameter server
  // this is necessary because dense weight initialized in trainer on cold
  // start
  virtual std::future<int32_t> PushDenseParam(const Region* regions,
                                              size_t region_num,
                                              size_t table_id) = 0;

  virtual std::future<int32_t> PushDense(const Region* regions,
                                         size_t region_num,
                                         size_t table_id) = 0;

  // Pull by keys, filling `select_values`.
  // - keys and values both have `num` entries; each value occupies
  //   select_size bytes.
  // - Buffers must not be reused before the future resolves.
  // - Keys from multiple threads are merged, then scattered to servers; on
  //   return, the buffers are traversed and values assigned.
  // - is_training distinguishes train vs. predict requests: the server
  //   treats feature admission differently for each.
  virtual std::future<int32_t> PullSparse(float** select_values,
                                          size_t table_id,
                                          const uint64_t* keys, size_t num,
                                          bool is_training) = 0;

  // Optional; default reports "not implemented" via a ready -1 future.
  virtual std::future<int32_t> PullSparseParam(float** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num, bool is_training) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  // Optional; default reports "not implemented" via a ready -1 future.
  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;

  // Ensure all batched pending requests are actually sent out.
  virtual std::future<int32_t> Flush() = 0;

  // Graceful server shutdown.
  virtual std::future<int32_t> StopServer() = 0;

  // Server-side profiler control.
  virtual std::future<int32_t> StartProfiler() = 0;
  virtual std::future<int32_t> StopProfiler() = 0;

  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type) = 0;

  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) = 0;

  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) = 0;

  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) = 0;

  virtual void FinalizeWorker() = 0;

  // Client-to-client message send; default reports "not implemented".
  virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                                    int to_client_id,
                                                    const std::string& msg) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  // client2client message handler: std::function<int32_t(int, int, const
  // std::string&)> -> ret (msg_type, from_client_id, msg)
  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;

  virtual int RegisteClient2ClientMsgHandler(int msg_type,
                                             MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }

  // Dispatch an incoming client2client message to its registered handler;
  // returns -1 for unknown message types.
  virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
                                     const std::string& msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
      return -1;
    }
    return itr->second(msg_type, from_client_id, msg);
  }

  // Non-owning accessor lookup; NULL when the table id is unknown.
  virtual ValueAccessor* GetTableAccessor(size_t table_id) {
    auto itr = _table_accessors.find(table_id);
    if (itr == _table_accessors.end()) {
      return NULL;
    }
    return itr->second.get();
  }

  virtual size_t GetServerNums() = 0;

  virtual std::future<int32_t> PushDenseRawGradient(
      int table_id, float* total_send_data, size_t total_send_data_size,
      void* done) = 0;

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id, const uint64_t* keys, const float** update_values,
      size_t num, void* done) = 0;

  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id, const uint64_t* keys, const float** update_values,
      uint32_t num, void* done, int pserver_idx) = 0;

  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num, void* done) = 0;

  virtual std::future<int32_t> PushSparse(size_t table_id,
                                          const uint64_t* keys,
                                          const float** update_values,
                                          size_t num) = 0;

  // for save cache
  virtual std::future<int32_t> CacheShuffle(
      uint32_t table_id, const std::string& path, const std::string& mode,
      const std::string& cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables, const std::string& path,
      const std::string& mode, const std::string& cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> SaveCache(uint32_t table_id,
                                         const std::string& path,
                                         const std::string& mode) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id,
                                                 double& cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

 protected:
  virtual int32_t Initialize() = 0;
  size_t _client_id;
  PSParameter _config;
  std::map<uint64_t, std::vector<paddle::distributed::Region>>
      _dense_pull_regions;
  PSEnvironment* _env;
  std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>>
      _table_accessors;
  std::unordered_map<int32_t, MsgHandlerFunc>
      _msg_handler_map;  // handles client2client messages
};
// A queued asynchronous request: the payload `_data`, its target table, a
// cost timer, and a shared promise that callers can wait on via get_future().
template <class T>
class AsyncRequestTask {
 public:
  AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}

  // Takes ownership of `data` by moving out of the caller's lvalue.
  AsyncRequestTask(T& data, size_t table_id,
                   std::shared_ptr<CostTimer>& timer)
      : _table_id(table_id),
        _timer(timer),
        _promise(std::make_shared<std::promise<int32_t>>()) {
    _data = std::move(data);
  }

  // NOTE: deliberately non-const (NOLINT) — this "copy" constructor MOVES the
  // payload out of `data`, leaving the source task's data in a moved-from
  // state while sharing its promise/timer.
  AsyncRequestTask(AsyncRequestTask& data)  // NOLINT
      : _table_id(data.table_id()),
        _timer(data.timer()),
        _promise(data.promise()) {
    _data = std::move(data.data());
  }

  ~AsyncRequestTask() {}

  inline T& data() { return _data; }
  inline size_t table_id() { return _table_id; }
  inline std::shared_ptr<CostTimer>& timer() { return _timer; }
  // Each task's future may be obtained once (std::promise semantics).
  inline std::future<int32_t> get_future() { return _promise->get_future(); }
  inline std::shared_ptr<std::promise<int32_t>>& promise() { return _promise; }

 private:
  T _data;
  size_t _table_id;
  std::shared_ptr<CostTimer> _timer;
  std::shared_ptr<std::promise<int32_t>> _promise;
};
// Declare the PSCore registry used by the REGISTER_PSCORE_CLASS calls above.
REGISTER_PSCORE_REGISTERER(PSClient);
// Builds a PSClient from configuration; the concrete class is chosen by the
// `client_class` string in the config (see ps_client.cc).
class PSClientFactory {
 public:
  static PSClient* Create(const PSParameter& config);
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_local_client.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
//#define pslib_debug_dense_compress
namespace
paddle
{
namespace
distributed
{
// Build the in-process tables described by the server config: each table is
// instantiated by class name, set up as the single shard (0 of 1),
// initialized, and indexed by table id.
int32_t PsLocalClient::Initialize() {
  const auto& downpour_param = _config.server_param().downpour_server_param();
  TableManager::Instance().Initialize();
  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());
    // Local client hosts everything itself: shard 0 of 1.
    table->SetShard(0, 1);
    table->Initialize(downpour_param.downpour_table_param(i),
                      _config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }
  return 0;
}
// Not implemented for the local client; returns an already-resolved future.
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  // TODO
  return done();
}
// Load every registered table by delegating to the per-table overload, then
// report completion through an already-resolved future.
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  for (auto& entry : _table_map) {
    Load(entry.first, epoch, mode);
  }
  return done();
}
// Load one table's data synchronously; the returned future is already
// resolved (local execution has nothing to wait on).
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  GetTable(table_id)->Load(epoch, mode);
  return done();
}
// Save every registered table by delegating to the per-table overload, then
// report completion through an already-resolved future.
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  for (auto& entry : _table_map) {
    Save(entry.first, epoch, mode);
  }
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
Save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
{
// TODO
auto
*
table_ptr
=
GetTable
(
table_id
);
table_ptr
->
Flush
();
table_ptr
->
Save
(
epoch
,
mode
);
return
done
();
}
// Not implemented for the local client; returns an already-resolved future.
::std::future<int32_t> PsLocalClient::Clear() {
  // TODO
  return done();
}
// Not implemented for the local client; returns an already-resolved future.
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
  // TODO
  return done();
}
// Nothing is batched locally, so there is nothing to flush.
::std::future<int32_t> PsLocalClient::Flush() {
  // no need
  return done();
}
// There is no remote server to stop in local mode.
::std::future<int32_t> PsLocalClient::StopServer() {
  // no need
  return done();
}
// Pull the whole dense parameter shard into a temporary float buffer, then
// scatter it byte-wise across the caller's Region array. Region.size is in
// BYTES; `index` tracks the byte offset consumed from region_buffer.
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  // Single shard (0 of 1), so num_per_shard is the full dense dimension.
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> region_buffer;
  region_buffer.resize(num_per_shard);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  table_ptr->Pull(table_context);
  // table_ptr->PullDense(region_buffer.data(), region_buffer.size());

  size_t region_idx = 0;       // current region in `regions`
  size_t region_data_idx = 0;  // byte offset already filled in that region
  size_t shard_data_size = num_per_shard;
  size_t shard_buffer_remain = shard_data_size * sizeof(float);
  PADDLE_ENFORCE_EQ(
      shard_buffer_remain, region_buffer.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));
  size_t index = 0;  // byte offset consumed from region_buffer
  while (shard_buffer_remain > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    if (region.size - region_data_idx >= shard_buffer_remain) {
      // Remaining shard data fits in this region: copy everything and stop.
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             shard_buffer_remain);
      region_data_idx += shard_buffer_remain;
      shard_buffer_remain = 0;
    } else if (region.size - region_data_idx == 0) {
      // Region already full: advance to the next one.
      ++region_idx;
      region_data_idx = 0;
    } else {
      // Fill the rest of this region, then advance.
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             region.size - region_data_idx);
      shard_buffer_remain -= (region.size - region_data_idx);
      index += (region.size - region_data_idx);
      ++region_idx;
      region_data_idx = 0;
    }
  }
  return done();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushDenseParam
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
{
auto
*
accessor
=
GetTableAccessor
(
table_id
);
auto
*
table_ptr
=
GetTable
(
table_id
);
std
::
vector
<
float
>
region_buffer
;
region_buffer
.
resize
(
DenseDimPerShard
(
accessor
->
GetAccessorInfo
().
fea_dim
,
1
),
0
);
for
(
size_t
i
=
0
,
offset
=
0
;
i
<
region_num
;
++
i
)
{
uint32_t
data_num
=
regions
[
i
].
size
/
sizeof
(
float
);
memcpy
(
region_buffer
.
data
()
+
offset
,
regions
[
i
].
data
,
regions
[
i
].
size
);
offset
+=
data_num
;
}
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
push_context
.
values
=
region_buffer
.
data
();
table_context
.
push_context
.
is_param
=
true
;
table_context
.
num
=
region_buffer
.
size
();
table_ptr
->
Push
(
table_context
);
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return
done
();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushDenseRawGradient
(
int
table_id
,
float
*
total_send_data
,
size_t
total_send_data_size
,
void
*
callback
)
{
VLOG
(
1
)
<<
"wxx push_dense_raw_gradient"
;
PSClientClosure
*
closure
=
reinterpret_cast
<
PSClientClosure
*>
(
callback
);
auto
*
table_ptr
=
GetTable
(
table_id
);
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
push_context
.
values
=
total_send_data
;
table_context
.
num
=
total_send_data_size
;
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr
->
Push
(
table_context
);
delete
closure
;
return
done
();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushDense
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
)
{
auto
*
accessor
=
GetTableAccessor
(
table_id
);
auto
*
table_ptr
=
GetTable
(
table_id
);
std
::
vector
<
float
>
region_buffer
;
region_buffer
.
resize
(
DenseDimPerShard
(
accessor
->
GetAccessorInfo
().
fea_dim
,
1
));
size_t
data_size
=
region_buffer
.
size
();
for
(
size_t
i
=
0
,
offset
=
0
;
i
<
region_num
;
++
i
)
{
uint32_t
data_num
=
regions
[
i
].
size
/
sizeof
(
float
);
PADDLE_ENFORCE_LE
(
offset
+
data_num
,
data_size
,
platform
::
errors
::
PreconditionNotMet
(
"invalid dense size, cur pos[%d] data_num[%d] size[%d]"
,
offset
,
data_num
,
data_size
));
memcpy
(
region_buffer
.
data
()
+
offset
,
regions
[
i
].
data
,
regions
[
i
].
size
);
offset
+=
data_num
;
}
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
push_context
.
values
=
region_buffer
.
data
();
table_context
.
num
=
region_buffer
.
size
();
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr
->
Push
(
table_context
);
return
done
();
}
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
// // FIXME
// // auto timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// //将key拆分到各shard请求,并记录原始对应value指针
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
// for (int i = 0; i < num; ++i) {
// memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
// offset += value_size;
// }
//
// // return fut;
// return done();
//}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PullSparsePtr
(
char
**
select_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
)
{
// FIXME
// auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//将key拆分到各shard请求,并记录原始对应value指针
auto
*
table_ptr
=
GetTable
(
table_id
);
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
pull_context
.
keys
=
keys
;
table_context
.
pull_context
.
ptr_values
=
select_values
;
table_context
.
use_ptr
=
true
;
table_context
.
num
=
num
;
// table_ptr->PullSparsePtr(select_values, keys, num);
table_ptr
->
Pull
(
table_context
);
return
done
();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushSparseRawGradient
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
,
void
*
callback
)
{
PSClientClosure
*
closure
=
reinterpret_cast
<
PSClientClosure
*>
(
callback
);
auto
*
table_ptr
=
GetTable
(
table_id
);
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
push_context
.
keys
=
keys
;
table_context
.
push_context
.
ptr_values
=
update_values
;
table_context
.
num
=
num
;
table_context
.
use_ptr
=
true
;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr
->
Push
(
table_context
);
delete
closure
;
return
done
();
}
::
std
::
future
<
int32_t
>
PsLocalClient
::
PushSparse
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
)
{
auto
*
table_ptr
=
GetTable
(
table_id
);
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
push_context
.
keys
=
keys
;
table_context
.
push_context
.
ptr_values
=
update_values
;
table_context
.
num
=
num
;
table_context
.
use_ptr
=
true
;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr
->
Push
(
table_context
);
return
done
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_local_client.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
namespace
paddle
{
namespace
distributed
{
class Table;

// In-process implementation of PSClient: every RPC of the distributed
// client is served directly against local Table objects, so all futures
// returned here are already resolved when handed back. Methods that only
// make sense in a truly distributed deployment (geo params, barriers,
// profiling, client-to-client messaging) are no-op stubs returning 0.
class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }

  // No remote peers exist locally; connection setup is a no-op.
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
    return 0;
  }

  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
                                        const std::string threshold) override;
  // Load all tables / one table from `epoch` path with load `mode`.
  virtual ::std::future<int32_t> Load(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  // Save all tables / one table to `epoch` path with save `mode`.
  virtual ::std::future<int32_t> Save(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
  virtual ::std::future<int32_t> StopServer() override;
  virtual void FinalizeWorker() override {}

  // Dense value transfer between caller-owned Regions and local tables;
  // implemented in ps_local_client.cc.
  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDense(const Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);

  // Copying sparse pull is not supported locally; succeeds without
  // touching `select_values` (use PullSparsePtr instead).
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
                                            const uint64_t* keys,
                                            size_t num,
                                            bool is_training) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);

  // Stub: nothing to report for a local table.
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);

  virtual ::std::future<int32_t> Flush();

  // Server profiler controls: no-ops for the local client.
  virtual std::future<int32_t> StartProfiler() {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  };
  virtual std::future<int32_t> StopProfiler() {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // Stub: barriers are meaningless with a single in-process participant.
  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // Stub: geo-async parameter exchange is a distributed-only feature.
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // Stub: global-step tracking is not maintained locally.
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // recv table from server and save it in LodTensor; stub locally.
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
    return 0;
  }

  // Stub: there is no other client to message.
  virtual ::std::future<int32_t> SendClient2ClientMsg(
      int msg_type, int to_client_id, const std::string& msg) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // The local deployment always behaves as a single server.
  virtual size_t GetServerNums() { return 1; }

  virtual std::future<int32_t> PushDenseRawGradient(
      int table_id,
      float* total_send_data,
      size_t total_send_data_size,
      void* callback) override;

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback) override;

  // Stub: partial raw-gradient push targets a specific pserver, which
  // does not exist locally.
  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // Stub: sparse-parameter push is not implemented for the local client.
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

 private:
  virtual int32_t Initialize() override;

  // Helper: build a future that is already fulfilled with 0 (success).
  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

  // Per-shard dense dimension; the +1 rounds up so every element fits.
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
    return &_table_map;
  }

  // Look up a table by id; logs and returns NULL when absent.
  inline Table* GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  // table_id -> table instance; populated during Initialize().
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;

  bool _running = false;
  bool _flushing = false;

 private:
  // Push-side statistics accumulators.
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_local_server.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/ps/service/server.h"
namespace
paddle
{
namespace
distributed
{
// In-process stand-in for PSServer: since PsLocalClient talks to tables
// directly, no real server is needed and every lifecycle method is a
// successful no-op.
class PsLocalServer : public PSServer {
 public:
  PsLocalServer() {}
  virtual ~PsLocalServer() {}
  // No listening endpoint exists; Start/Stop report success immediately.
  virtual uint64_t Start() { return 0; }
  virtual uint64_t Start(const std::string& ip, uint32_t port) { return 0; }
  virtual int32_t Stop() { return 0; }
  // Configuration is ignored; tables live in the local client instead.
  virtual int32_t Configure(
      const PSParameter& config, PSEnvironment& env, size_t server_rank,
      const std::vector<framework::ProgramDesc>& server_sub_program = {}) {
    return 0;
  }

 private:
  virtual int32_t Initialize() { return 0; }
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"
#include <thread> // NOLINT
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace
paddle
{
namespace
distributed
{
// Splits `str` on the single-character delimiter `pattern` and returns the
// pieces in order. Empty segments between consecutive delimiters are kept;
// a trailing delimiter does not produce a trailing empty segment
// (std::getline semantics).
std::vector<std::string> GraphPyService::split(std::string& str,
                                               const char pattern) {
  std::vector<std::string> pieces;
  std::stringstream source(str);
  std::string segment;
  while (std::getline(source, segment, pattern)) {
    pieces.push_back(segment);
  }
  return pieces;
}
// Registers (or overwrites) one feature column for the node table named
// `table_name`: records the feature's name, dtype, and shape at a stable
// per-table slot index. Unknown table names are silently ignored.
//
// @param table_name  node-type table to attach the feature to; must already
//                    be present in `feature_to_id` (filled by set_up()).
// @param feat_name   feature identifier; first registration assigns it the
//                    next free slot, later calls overwrite that slot.
// @param feat_dtype  dtype string stored verbatim.
// @param feat_shape  shape stored verbatim.
void GraphPyService::add_table_feat_conf(std::string table_name,
                                         std::string feat_name,
                                         std::string feat_dtype,
                                         int feat_shape) {
  if (feature_to_id.find(table_name) != feature_to_id.end()) {
    int idx = feature_to_id[table_name];
    VLOG(0) << "for table name" << table_name << " idx = " << idx;
    // First time this feature is seen for the table: assign the next
    // sequential slot id.
    if (table_feat_mapping[idx].find(feat_name) ==
        table_feat_mapping[idx].end()) {
      VLOG(0) << "for table name not found,make a new one";
      int res = (int)table_feat_mapping[idx].size();
      table_feat_mapping[idx][feat_name] = res;
      VLOG(0) << "seq id = " << table_feat_mapping[idx][feat_name];
    }
    int feat_idx = table_feat_mapping[idx][feat_name];
    VLOG(0) << "table_name " << table_name << " mapping id " << idx;
    VLOG(0) << " feat name " << feat_name << " feat id" << feat_idx;
    if (static_cast<size_t>(feat_idx) <
        table_feat_conf_feat_name[idx].size()) {
      // Slot already exists in the parallel config vectors: override it.
      table_feat_conf_feat_name[idx][feat_idx] = feat_name;
      table_feat_conf_feat_dtype[idx][feat_idx] = feat_dtype;
      table_feat_conf_feat_shape[idx][feat_idx] = feat_shape;
    } else {
      // New slot: append to keep the three vectors index-aligned.
      table_feat_conf_feat_name[idx].push_back(feat_name);
      table_feat_conf_feat_dtype[idx].push_back(feat_dtype);
      table_feat_conf_feat_shape[idx].push_back(feat_shape);
    }
  }
  VLOG(0) << "add conf over";
}
// NOTE(review): free-function placeholder with an empty body; the working
// implementation is the GraphPyClient::add_graph_node member further below.
// This appears to be dead code — confirm no caller relies on it before
// removing.
void add_graph_node(std::string name,
                    std::vector<int64_t> node_ids,
                    std::vector<bool> weight_list) {}
// NOTE(review): free-function placeholder with an empty body; the working
// implementation is the GraphPyClient::remove_graph_node member further
// below. Appears to be dead code — confirm before removing.
void remove_graph_node(std::string name, std::vector<int64_t> node_ids) {}
// One-time service setup: assigns dense ids to edge and node types and
// parses the semicolon-separated "ip:port" list into the server roster.
//
// @param ips_str     servers as "ip1:port1;ip2:port2;...".
// @param shard_num   number of graph shards.
// @param node_types  node (feature) type names; index becomes the type id.
// @param edge_types  edge type names; index becomes the type id.
void GraphPyService::set_up(std::string ips_str, int shard_num,
                            std::vector<std::string> node_types,
                            std::vector<std::string> edge_types) {
  set_shard_num(shard_num);
  set_num_node_types(node_types.size());
  /*
  int num_node_types;
  std::unordered_map<std::string, uint32_t> edge_idx, feature_idx;
  std::vector<std::unordered_map<std::string,uint32_t>> table_feat_mapping;
  std::vector<std::vector<std::string>> table_feat_conf_feat_name;
  std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
  std::vector<std::vector<int32_t>> table_feat_conf_feat_shape;
  */
  // Edge types: id is the insertion order into edge_to_id.
  id_to_edge = edge_types;
  for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
    int res = (int)edge_to_id.size();
    edge_to_id[edge_types[table_id]] = res;
  }
  // Node (feature) types: same id scheme via feature_to_id.
  id_to_feature = node_types;
  for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
    int res = (int)feature_to_id.size();
    feature_to_id[node_types[table_id]] = res;
  }
  // Per-node-type feature configuration containers, index-aligned with
  // the node-type ids assigned above.
  table_feat_mapping.resize(node_types.size());
  this->table_feat_conf_feat_name.resize(node_types.size());
  this->table_feat_conf_feat_dtype.resize(node_types.size());
  this->table_feat_conf_feat_shape.resize(node_types.size());
  std::istringstream stream(ips_str);
  std::string ip;
  server_size = 0;
  // Parse "ip:port" entries; each entry's position becomes its host index.
  std::vector<std::string> ips_list = split(ips_str, ';');
  int index = 0;
  VLOG(0) << "start to build server";
  for (auto ips : ips_list) {
    auto ip_and_port = split(ips, ':');
    server_list.push_back(ip_and_port[0]);
    port_list.push_back(ip_and_port[1]);
    uint32_t port = stoul(ip_and_port[1]);
    auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
    host_sign_list.push_back(ph_host.SerializeToString());
    index++;
  }
  VLOG(0) << "build server done";
}
// Creates and configures the brpc graph client (`worker_ptr`) against the
// server roster built by set_up(). Must be called after set_up().
void GraphPyClient::start_client() {
  // A single empty dense-region entry: the graph client needs the map's
  // shape but no dense parameters.
  std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
  dense_regions.insert(
      std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
  auto regions = dense_regions[0];
  ::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
  paddle::distributed::PaddlePSEnvironment _ps_env;
  auto servers_ = host_sign_list.size();
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(&host_sign_list, servers_);
  // The factory returns a PSClient*; the proto selects GraphBrpcClient,
  // hence the downcast.
  worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
      (paddle::distributed::GraphBrpcClient*)
          paddle::distributed::PSClientFactory::Create(worker_proto));
  worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
  worker_ptr->set_shard_num(get_shard_num());
}
// Builds, configures, and starts this rank's brpc graph server, then
// optionally blocks the calling thread until the server signals its
// shutdown condition variable.
//
// @param block  when true, wait on the server's exported condition
//               variable instead of returning immediately.
void GraphPyServer::start_server(bool block) {
  std::string ip = server_list[rank];
  uint32_t port = std::stoul(port_list[rank]);
  ::paddle::distributed::PSParameter server_proto = this->GetServerProto();

  auto _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(&this->host_sign_list,
                       this->host_sign_list.size());  // test
  // Factory returns PSServer*; the proto selects GraphBrpcServer.
  pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
      (paddle::distributed::GraphBrpcServer*)
          paddle::distributed::PSServerFactory::Create(server_proto));
  VLOG(0) << "pserver-ptr created ";
  // Configure() requires at least one (possibly empty) sub-program.
  std::vector<framework::ProgramDesc> empty_vec;
  framework::ProgramDesc empty_prog;
  empty_vec.push_back(empty_prog);
  pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
  pserver_ptr->Start(ip, port);
  pserver_ptr->build_peer2peer_connection(rank);
  std::condition_variable* cv_ = pserver_ptr->export_cv();
  if (block) {
    // NOTE(review): wait() has no predicate, so a spurious wakeup would
    // return early; presumably the server only notifies on shutdown —
    // confirm against GraphBrpcServer::export_cv usage.
    std::mutex mutex_;
    std::unique_lock<std::mutex> lock(mutex_);
    cv_->wait(lock);
  }
}
// Builds the PSParameter proto describing this graph server: the brpc
// graph service/server/client classes, port/thread settings, and one
// downpour sparse table filled in by GetDownpourSparseTableProto().
::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
  // Generate server proto desc
  ::paddle::distributed::PSParameter server_fleet_desc;
  ::paddle::distributed::ServerParameter* server_proto =
      server_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  // Port 0: the actual port is supplied at Start(ip, port) time.
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);

  // for (auto& tuple : this->table_id_map) {
  //   VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(sparse_table_proto);
  //}

  return server_fleet_desc;
}
// Builds the PSParameter proto a worker needs: a downpour worker table
// plus a mirrored server section (service classes, thread count, and one
// sparse table) so the client knows how the servers are configured.
::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
  ::paddle::distributed::PSParameter worker_fleet_desc;
  ::paddle::distributed::WorkerParameter* worker_proto =
      worker_fleet_desc.mutable_worker_param();

  ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
      worker_proto->mutable_downpour_worker_param();

  // Worker-side sparse table description.
  // for (auto& tuple : this->table_id_map) {
  //   VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* worker_sparse_table_proto =
      downpour_worker_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(worker_sparse_table_proto);
  //}

  // Mirror of the server configuration (must match GetServerProto()).
  ::paddle::distributed::ServerParameter* server_proto =
      worker_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);

  // for (auto& tuple : this->table_id_map) {
  //   VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(sparse_table_proto);
  //}

  return worker_fleet_desc;
}
// Loads an edge file into the edge table named `name`.
//
// @param reverse  when true the directive becomes "e<name" (edges loaded
//                 $2 -> $1); otherwise "e>name" (edges loaded $1 -> $2).
// Unknown edge names are silently ignored.
void GraphPyClient::load_edge_file(std::string name, std::string filepath,
                                   bool reverse) {
  // 'e' selects edge loading; the arrow encodes the edge direction.
  const std::string direction = reverse ? "<" : ">";
  std::string params = "e" + direction + name;

  if (edge_to_id.find(name) != edge_to_id.end()) {
    auto status = get_ps_client()->Load(0, std::string(filepath), params);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   VLOG(0) << "loadding data with type " << name << " from " << filepath;
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status =
  //       get_ps_client()->Load(table_id, std::string(filepath), params);
  //   status.wait();
  // }
}
void
GraphPyClient
::
clear_nodes
(
std
::
string
name
)
{
if
(
edge_to_id
.
find
(
name
)
!=
edge_to_id
.
end
())
{
int
idx
=
edge_to_id
[
name
];
auto
status
=
get_ps_client
()
->
clear_nodes
(
0
,
0
,
idx
);
status
.
wait
();
}
else
if
(
feature_to_id
.
find
(
name
)
!=
feature_to_id
.
end
())
{
int
idx
=
feature_to_id
[
name
];
auto
status
=
get_ps_client
()
->
clear_nodes
(
0
,
1
,
idx
);
status
.
wait
();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = get_ps_client()->clear_nodes(table_id);
// status.wait();
// }
}
// Adds nodes (with per-node weight flags) to the edge table named `name`.
// Unknown edge names are silently ignored.
void GraphPyClient::add_graph_node(std::string name,
                                   std::vector<int64_t>& node_ids,
                                   std::vector<bool>& weight_list) {
  auto it = edge_to_id.find(name);
  if (it == edge_to_id.end()) {
    return;
  }
  auto status =
      get_ps_client()->add_graph_node(0, it->second, node_ids, weight_list);
  status.wait();
}
// Removes the given nodes from the edge table named `name`. Unknown edge
// names are silently ignored.
void GraphPyClient::remove_graph_node(std::string name,
                                      std::vector<int64_t>& node_ids) {
  auto it = edge_to_id.find(name);
  if (it == edge_to_id.end()) {
    return;
  }
  auto status = get_ps_client()->remove_graph_node(0, it->second, node_ids);
  status.wait();
}
// Loads a node file for the node type named `name`; the directive
// "n<node_type>" tells the server which node type to ingest. Unknown
// node-type names are silently ignored.
void GraphPyClient::load_node_file(std::string name, std::string filepath) {
  const std::string params = "n" + name;
  auto it = feature_to_id.find(name);
  if (it != feature_to_id.end()) {
    auto status = get_ps_client()->Load(0, std::string(filepath), params);
    status.wait();
  }
}
// Samples up to `sample_size` neighbors for each node in `node_ids` from
// the edge table named `name`, then flattens the per-node results into
// the layout the Python side expects.
//
// @param return_weight  when true, edge weights go into res.second.
// @param return_edges   when true, the source node of each sampled edge
//                       is appended as res.first[2].
// @return res.first[0]: flattened neighbor ids;
//         res.first[1]: cumulative slice offsets (one per source node,
//                       excluding the last — see loop below);
//         res.first[2]: source node per neighbor (only if return_edges);
//         res.second  : weight per neighbor (only if return_weight).
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
                                      std::vector<int64_t> node_ids,
                                      int sample_size, bool return_weight,
                                      bool return_edges) {
  std::vector<std::vector<int64_t>> v;
  std::vector<std::vector<float>> v1;
  if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->batch_sample_neighbors(
        0, idx, node_ids, sample_size, v, v1, return_weight);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status = worker_ptr->batch_sample_neighbors(
  //       table_id, node_ids, sample_size, v, v1, return_weight);
  //   status.wait();
  // }

  // res.first[0]: neighbors (nodes)
  // res.first[1]: slice index
  // res.first[2]: src nodes
  // res.second: edges weight
  std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
  res.first.push_back({});
  res.first.push_back({});
  if (return_edges) res.first.push_back({});
  for (size_t i = 0; i < v.size(); i++) {
    // Flatten this node's sampled neighbors (and parallel weight/source
    // columns when requested).
    for (size_t j = 0; j < v[i].size(); j++) {
      // res.first[0].push_back(v[i][j].first);
      res.first[0].push_back(v[i][j]);
      if (return_edges) res.first[2].push_back(node_ids[i]);
      if (return_weight) res.second.push_back(v1[i][j]);
    }
    // Slice offsets are cumulative counts; the final node's offset is
    // intentionally omitted (the consumer infers it from the total size).
    if (i == v.size() - 1) break;
    if (i == 0) {
      res.first[1].push_back(v[i].size());
    } else {
      res.first[1].push_back(v[i].size() + res.first[1].back());
    }
  }
  return res;
}
std
::
vector
<
int64_t
>
GraphPyClient
::
random_sample_nodes
(
std
::
string
name
,
int
server_index
,
int
sample_size
)
{
std
::
vector
<
int64_t
>
v
;
if
(
feature_to_id
.
find
(
name
)
!=
feature_to_id
.
end
())
{
int
idx
=
feature_to_id
[
name
];
auto
status
=
get_ps_client
()
->
random_sample_nodes
(
0
,
1
,
idx
,
server_index
,
sample_size
,
v
);
status
.
wait
();
}
else
if
(
edge_to_id
.
find
(
name
)
!=
edge_to_id
.
end
())
{
int
idx
=
edge_to_id
[
name
];
auto
status
=
get_ps_client
()
->
random_sample_nodes
(
0
,
0
,
idx
,
server_index
,
sample_size
,
v
);
status
.
wait
();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status =
// worker_ptr->random_sample_nodes(table_id, server_index, sample_size,
// v);
// status.wait();
// }
return
v
;
}
// (name, dtype, ndarray)
std
::
vector
<
std
::
vector
<
std
::
string
>>
GraphPyClient
::
get_node_feat
(
std
::
string
name
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
)
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
v
(
feature_names
.
size
(),
std
::
vector
<
std
::
string
>
(
node_ids
.
size
()));
if
(
feature_to_id
.
find
(
name
)
!=
feature_to_id
.
end
())
{
int
idx
=
feature_to_id
[
name
];
auto
status
=
get_ps_client
()
->
get_node_feat
(
0
,
idx
,
node_ids
,
feature_names
,
v
);
status
.
wait
();
}
// if (this->table_id_map.count(node_type)) {
// uint32_t table_id = this->table_id_map[node_type];
// auto status =
// worker_ptr->get_node_feat(table_id, node_ids, feature_names, v);
// status.wait();
// }
return
v
;
}
// Writes feature values for the given nodes of node type `name`.
// `features` is indexed [feature][node], matching get_node_feat().
// Unknown node-type names are silently ignored.
void GraphPyClient::set_node_feat(
    std::string name, std::vector<int64_t> node_ids,
    std::vector<std::string> feature_names,
    const std::vector<std::vector<std::string>> features) {
  auto it = feature_to_id.find(name);
  if (it == feature_to_id.end()) {
    return;
  }
  auto status = get_ps_client()->set_node_feat(0, it->second, node_ids,
                                               feature_names, features);
  status.wait();
}
std
::
vector
<
FeatureNode
>
GraphPyClient
::
pull_graph_list
(
std
::
string
name
,
int
server_index
,
int
start
,
int
size
,
int
step
)
{
std
::
vector
<
FeatureNode
>
res
;
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
// size, step, res);
// status.wait();
// }
if
(
feature_to_id
.
find
(
name
)
!=
feature_to_id
.
end
())
{
int
idx
=
feature_to_id
[
name
];
auto
status
=
get_ps_client
()
->
pull_graph_list
(
0
,
1
,
idx
,
server_index
,
start
,
size
,
step
,
res
);
status
.
wait
();
}
else
if
(
edge_to_id
.
find
(
name
)
!=
edge_to_id
.
end
())
{
int
idx
=
edge_to_id
[
name
];
auto
status
=
get_ps_client
()
->
pull_graph_list
(
0
,
0
,
idx
,
server_index
,
start
,
size
,
step
,
res
);
status
.
wait
();
}
return
res
;
}
// Asks the remote server to stop, at most once: the mutex serializes
// concurrent callers and `stoped_` makes the request idempotent.
void GraphPyClient::StopServer() {
  VLOG(0) << "going to stop server";
  std::unique_lock<std::mutex> guard(mutex_);
  if (stoped_) {
    return;
  }
  auto status = this->worker_ptr->StopServer();
  // Only mark stopped on a successful (0) response.
  if (status.get() == 0) {
    stoped_ = true;
  }
}
void
GraphPyClient
::
FinalizeWorker
()
{
this
->
worker_ptr
->
FinalizeWorker
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
0 → 100644
View file @
d2d32668
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
distributed
{
// Shared configuration base for the Python-facing graph service wrappers
// (GraphPyServer / GraphPyClient). Holds the cluster host list, the sharding
// configuration, and the edge/feature table name <-> id mappings, and can
// render them into a GraphTable TableParameter proto.
class GraphPyService {
 protected:
  // Parsed from the "ip:port" list handed to set_up().
  std::vector<std::string> server_list, port_list, host_sign_list;
  int server_size, shard_num;
  int num_node_types;
  // Lookup tables: table name -> dense index, and the reverse vectors.
  std::unordered_map<std::string, int> edge_to_id, feature_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
  // Per-node-type feature-name -> slot-index mapping.
  std::vector<std::unordered_map<std::string, int>> table_feat_mapping;
  // Per-node-type parallel arrays describing each configured feature
  // (filled by add_table_feat_conf; consumed by GetDownpourSparseTableProto).
  std::vector<std::vector<std::string>> table_feat_conf_feat_name;
  std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
  std::vector<std::vector<int>> table_feat_conf_feat_shape;

 public:
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  // Fills `sparse_table_proto` with a GraphTable configuration built from the
  // edge/feature metadata accumulated so far (table id 0, shard_num shards,
  // 24 task-pool threads, cache disabled).
  void GetDownpourSparseTableProto(
      ::paddle::distributed::TableParameter* sparse_table_proto) {
    sparse_table_proto->set_table_id(0);
    sparse_table_proto->set_table_class("GraphTable");
    sparse_table_proto->set_shard_num(shard_num);
    sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
    ::paddle::distributed::TableAccessorParameter* accessor_proto =
        sparse_table_proto->mutable_accessor();
    // ::paddle::distributed::CommonAccessorParameter* common_proto =
    //     sparse_table_proto->mutable_common();
    ::paddle::distributed::GraphParameter* graph_proto =
        sparse_table_proto->mutable_graph_parameter();
    // ::paddle::distributed::GraphFeature* graph_feature =
    //     graph_proto->mutable_graph_feature();
    graph_proto->set_task_pool_size(24);
    graph_proto->set_table_name("cpu_graph_table");
    graph_proto->set_use_cache(false);
    for (size_t i = 0; i < id_to_edge.size(); i++)
      graph_proto->add_edge_types(id_to_edge[i]);
    for (size_t i = 0; i < id_to_feature.size(); i++) {
      graph_proto->add_node_types(id_to_feature[i]);
      auto feat_node = id_to_feature[i];
      // One GraphFeature message per node type, carrying the parallel
      // name/dtype/shape arrays configured via add_table_feat_conf().
      ::paddle::distributed::GraphFeature* g_f =
          graph_proto->add_graph_feature();
      for (size_t x = 0; x < table_feat_conf_feat_name[i].size(); x++) {
        g_f->add_name(table_feat_conf_feat_name[i][x]);
        g_f->add_dtype(table_feat_conf_feat_dtype[i][x]);
        g_f->add_shape(table_feat_conf_feat_shape[i][x]);
      }
    }
    // Set GraphTable Parameter
    // common_proto->set_table_name(table_name);
    // common_proto->set_name(table_type);
    // for (size_t i = 0; i < feat_name.size(); i++) {
    //   common_proto->add_params(feat_dtype[i]);
    //   common_proto->add_dims(feat_shape[i]);
    //   common_proto->add_attributes(feat_name[i]);
    // }
    // for (size_t i = 0; i < feat_name.size(); i++) {
    //   graph_feature->add_dtype(feat_dtype[i]);
    //   graph_feature->add_shape(feat_shape[i]);
    //   graph_feature->add_name(feat_name[i]);
    // }
    accessor_proto->set_accessor_class("CommMergeAccessor");
  }
  void set_server_size(int server_size) { this->server_size = server_size; }
  void set_num_node_types(int num_node_types) {
    this->num_node_types = num_node_types;
  }
  // NOTE(review): this "getter" returns its own argument, not the
  // `server_size` member (the parameter shadows the member). Looks like a
  // bug — callers passing any value get it echoed back. Confirm intent
  // before relying on it; the signature is kept unchanged here.
  int get_server_size(int server_size) { return server_size; }
  // Splits `str` on `pattern`; used to parse the "ip:port" host list.
  std::vector<std::string> split(std::string& str, const char pattern);
  // Parses `ips_str`, records sharding and the node/edge type tables.
  void set_up(std::string ips_str, int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types);
  // Registers one feature (name/dtype/shape) for node type `node_type`.
  void add_table_feat_conf(std::string node_type, std::string feat_name,
                           std::string feat_dtype, int32_t feat_shape);
};
// Python-facing wrapper around a GraphBrpcServer instance: one per server
// rank. set_up() stores the rank and delegates cluster parsing to the base.
class GraphPyServer : public GraphPyService {
 public:
  GraphPyServer() {}
  void set_up(std::string ips_str, int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types, int rank) {
    set_rank(rank);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
  int GetRank() { return rank; }
  void set_rank(int rank) { this->rank = rank; }
  // Starts serving; blocks the calling thread when `block` is true.
  void start_server(bool block = true);
  // Builds the PSParameter proto describing this server's configuration.
  ::paddle::distributed::PSParameter GetServerProto();
  std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() {
    return pserver_ptr;
  }

 protected:
  int rank;
  std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr;
  // Raw owning pointer to the serving thread; lifetime managed manually
  // by start_server (no destructor visible here — TODO confirm cleanup).
  std::thread* server_thread;
};
// Python-facing graph PS client. Wraps a GraphBrpcClient and exposes
// load/query/sample operations keyed by edge- or feature-table name.
class GraphPyClient : public GraphPyService {
 public:
  void set_up(std::string ips_str, int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types, int client_id) {
    set_client_id(client_id);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
  std::shared_ptr<paddle::distributed::GraphBrpcClient> get_ps_client() {
    return worker_ptr;
  }
  // Short-circuits RPC through an in-process server: routes channel
  // `local_channel_index` directly to `server`'s brpc service object.
  void bind_local_server(int local_channel_index, GraphPyServer& server) {
    worker_ptr->set_local_channel(local_channel_index);
    worker_ptr->set_local_graph_service(
        (paddle::distributed::GraphBrpcService*)server.get_ps_server()
            ->get_service());
  }
  // Sends a stop request to the servers (idempotent; see .cc).
  void StopServer();
  // Releases client-side worker resources.
  void FinalizeWorker();
  // Bulk loaders: `name` selects the edge/node table the file feeds.
  void load_edge_file(std::string name, std::string filepath, bool reverse);
  void load_node_file(std::string name, std::string filepath);
  void clear_nodes(std::string name);
  void add_graph_node(std::string name, std::vector<int64_t>& node_ids,
                      std::vector<bool>& weight_list);
  void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
  int get_client_id() { return client_id; }
  void set_client_id(int client_id) { this->client_id = client_id; }
  void start_client();
  // Samples up to `sample_size` neighbors per input node; returns the
  // neighbor id lists and (optionally) their edge weights.
  std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
  batch_sample_neighbors(std::string name, std::vector<int64_t> node_ids,
                         int sample_size, bool return_weight,
                         bool return_edges);
  std::vector<int64_t> random_sample_nodes(std::string name, int server_index,
                                           int sample_size);
  // Fetches the named feature values for each node id.
  std::vector<std::vector<std::string>> get_node_feat(
      std::string name, std::vector<int64_t> node_ids,
      std::vector<std::string> feature_names);
  // NOTE(review): the .cc definition names this first parameter `name`
  // (a feature-table name), not `node_type` — the declaration's name is
  // misleading; confirm and align.
  void set_node_feat(std::string node_type, std::vector<int64_t> node_ids,
                     std::vector<std::string> feature_names,
                     const std::vector<std::vector<std::string>> features);
  // Pulls up to `size` entries from one server, offset `start`, stride `step`.
  std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
                                           int start, int size, int step = 1);
  ::paddle::distributed::PSParameter GetWorkerProto();

 protected:
  // Guards the stop handshake in StopServer().
  mutable std::mutex mutex_;
  int client_id;
  std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr;
  // Raw owning pointer; lifetime managed by start_client — TODO confirm.
  std::thread* client_thread;
  // Set once StopServer() succeeds; makes later calls no-ops.
  bool stoped_ = false;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_service/service.cc
0 → 100644
View file @
d2d32668
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/string/string_helper.h"
using
namespace
std
;
// NOLINT
namespace
paddle
{
namespace
distributed
{
// Reads a PSParameter message from a text-format protobuf file at `filename`.
// Fail-fast: terminates the process with exit(-1) when the file cannot be
// opened or its contents cannot be parsed (original behavior preserved).
paddle::distributed::PSParameter load_from_prototxt(
    const std::string& filename) {
  paddle::distributed::PSParameter param;
  int file_descriptor = open(filename.c_str(), O_RDONLY);
  if (file_descriptor == -1) {
    // Bug fix: this branch means open(2) failed, not that parsing failed —
    // the old message ("fail to parse") pointed reviewers at the wrong cause.
    VLOG(3) << "FATAL: fail to open " << filename;
    exit(-1);
  }
  google::protobuf::io::FileInputStream fileInput(file_descriptor);
  if (!google::protobuf::TextFormat::Parse(&fileInput, &param)) {
    VLOG(3) << "FATAL: fail to parse " << filename;
    close(file_descriptor);  // release the fd explicitly before terminating
    exit(-1);
  }
  close(file_descriptor);
  return param;
}
// Parses `gflags` (a whitespace-separated list of "-name=value" tokens) and
// feeds it to the gflags command-line parser. When the string yields no
// tokens, a default set of brpc transport limits is applied instead.
void PSCore::InitGFlag(const std::string& gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.empty()) {
    // Defaults used when the caller supplies no flags at all.
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  // ParseCommandLineFlags expects argv[0] to be a program name.
  flags.insert(flags.begin(), "exe default");
  // Fix: the original used a variable-length array (`char* p[flags.size()]`),
  // which is a non-standard compiler extension in C++; use std::vector.
  std::vector<char*> flags_ptr;
  flags_ptr.reserve(flags.size());
  for (auto& flag : flags) {
    // gflags wants mutable char**; the strings outlive the parse call.
    flags_ptr.push_back(const_cast<char*>(flag.c_str()));
  }
  int params_cnt = static_cast<int>(flags_ptr.size());
  char** params_ptr = flags_ptr.data();
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
// Builds the server side of a PS node: parses the text-format PSParameter
// from `dist_desc`, applies its embedded gflags, registers the server host
// list and trainer count in the PS environment, then creates the concrete
// PSServer via the factory and configures it. Returns Configure()'s status
// (0 on success; CHECK aborts the process on failure).
int PSCore::InitServer(
    const std::string& dist_desc,
    const std::vector<std::string>* host_sign_list, int node_num, int index,
    int trainers,
    const std::vector<framework::ProgramDesc>& server_sub_program) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
  InitGFlag(_ps_param.init_gflags());
  // Reset the environment before re-populating it (order matters: servers
  // must be registered before the server is configured against _ps_env).
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(host_sign_list, node_num);
  _ps_env.SetTrainers(trainers);
  int ret = 0;
  _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
      paddle::distributed::PSServerFactory::Create(_ps_param));
  ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
  CHECK(ret == 0) << "failed to configure server";
  return ret;
}
// Builds the worker side: parses PSParameter from `dist_desc`, applies its
// gflags, registers the server host list, then configures the PS client
// owned by the Communicator singleton and starts the communicator.
// Returns the client Configure() status.
// NOTE(review): `_worker_ptr` is never assigned here, yet sibling methods
// (StopServer, FinalizeWorker, CreateClient2ClientConnection) dereference
// it — presumably it is set elsewhere; verify before calling those.
int PSCore::InitWorker(
    const std::string& dist_desc,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
        regions,
    const std::vector<std::string>* host_sign_list, int node_num, int index) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
  InitGFlag(_ps_param.init_gflags());
  // Fresh environment populated with the server list before configuration.
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(host_sign_list, node_num);
  int ret = 0;
  VLOG(1) << "PSCore::InitWorker";
  auto* communicator = Communicator::GetInstance();
  ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
                                               index);
  communicator->Start();
  return ret;
}
// Returns the client info list recorded in the PS environment.
std::vector<uint64_t> PSCore::GetClientInfo() {
  auto client_info = _ps_env.GetClientInfo();
  return client_info;
}
// Establishes direct client-to-client connections through the worker client,
// forwarding the timeout and retry settings. Returns the client's status code.
int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
                                          int pserver_connect_timeout_ms,
                                          int max_retry) {
  return _worker_ptr->CreateClient2ClientConnection(
      pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
}
// Starts the PS server listening on ip:port; returns PSServer::Start's value.
uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
  const auto start_result = _server_ptr->Start(ip, port);
  return start_result;
}
int
PSCore
::
FinalizeWorker
()
{
_worker_ptr
->
FinalizeWorker
();
return
0
;
}
// Sends the stop request through the worker client and blocks until the
// servers acknowledge it. Always reports success (0).
int PSCore::StopServer() {
  auto stop_future = _worker_ptr->StopServer();
  stop_future.wait();
  return 0;
}
// Exposes the parsed PSParameter configuration owned by this PSCore.
// The returned pointer stays valid for the lifetime of the PSCore object.
paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/ps_service/service.h
0 → 100644
View file @
d2d32668
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/server.h"
namespace
paddle
{
namespace
distributed
{
// Forward declarations: only pointers/references to these types appear in
// this header, so the full definitions are not required here.
class PSClient;
class PSServer;
class PsRequestMessage;
class PsResponseMessage;
class PsService;
// Pull the protobuf-generated message/service names into this namespace.
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
// Facade bundling one parameter-server node's lifecycle: parse the
// PSParameter config, stand up either the server or the worker side, and
// expose stop/finalize hooks. Virtual so tests or variants can override.
class PSCore {
 public:
  // NOTE(review): `explicit` on a zero-argument constructor has no effect.
  explicit PSCore() {}
  virtual ~PSCore() {}
  // Initializes the server side from a text-format PSParameter in
  // `dist_desc`; see service.cc for the configuration sequence.
  virtual int InitServer(
      const std::string& dist_desc,
      const std::vector<std::string>* host_sign_list, int node_num, int index,
      int trainers,
      const std::vector<framework::ProgramDesc>& server_sub_program = {});
  // Initializes the worker side; `regions` maps table id -> dense regions.
  virtual int InitWorker(
      const std::string& dist_desc,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
          regions,
      const std::vector<std::string>* host_sign_list, int node_num, int index);
  // Starts the server listening on ip:port; returns PSServer::Start's value.
  virtual uint64_t RunServer(const std::string& ip, uint32_t port);
  virtual int StopServer();
  virtual int FinalizeWorker();
  virtual std::vector<uint64_t> GetClientInfo();
  virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
                                            int pserver_connect_timeout_ms,
                                            int max_retry);
  // Public shared handles to the server/worker implementations.
  // NOTE(review): InitWorker (in service.cc) never assigns _worker_ptr even
  // though StopServer/FinalizeWorker dereference it — presumably set by an
  // external caller; confirm.
  std::shared_ptr<paddle::distributed::PSServer> _server_ptr;
  // pointer to server
  std::shared_ptr<paddle::distributed::PSClient> _worker_ptr;
  // pointer to worker
  // Returns the parsed configuration; valid for this object's lifetime.
  virtual paddle::distributed::PSParameter* GetParam();

 private:
  // Applies the gflags string embedded in the PSParameter config.
  void InitGFlag(const std::string& gflags);
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
};
}
// namespace distributed
}
// namespace paddle
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment